diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d1c846c04827..6d46c3190626 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -59,3 +59,4 @@ Collate: 'window.R' RoxygenNote: 5.0.1 VignetteBuilder: knitr +NeedsCompilation: no diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 57838f52eac3..c51eb0f39c4b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -76,7 +76,9 @@ exportMethods("glm", export("setJobGroup", "clearJobGroup", "cancelJobGroup", - "setJobDescription") + "setJobDescription", + "setLocalProperty", + "getLocalProperty") # Export Utility methods export("setLogLevel") @@ -133,6 +135,7 @@ exportMethods("arrange", "isStreaming", "join", "limit", + "localCheckpoint", "merge", "mutate", "na.omit", @@ -176,6 +179,7 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", + "withWatermark", "write.df", "write.jdbc", "write.json", @@ -225,11 +229,14 @@ exportMethods("%<=>%", "crc32", "create_array", "create_map", + "current_date", + "current_timestamp", "hash", "cume_dist", "date_add", "date_format", "date_sub", + "date_trunc", "datediff", "dayofmonth", "dayofweek", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index b8d732a48586..9956f7eda91e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,7 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param ... additional sorting fields #' @param decreasing a logical argument indicating sorting order for columns when #' a character vector is specified for col +#' @param withinPartitions a logical argument indicating whether to sort only within each partition #' @return A SparkDataFrame where all elements are sorted. #' @family SparkDataFrame functions #' @aliases arrange,SparkDataFrame,Column-method @@ -2312,16 +2313,21 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' arrange(df, asc(df$col1), desc(abs(df$col2))) #' arrange(df, "col1", decreasing = TRUE) #' arrange(df, "col1", "col2", decreasing = c(TRUE, FALSE)) +#' arrange(df, "col1", "col2", withinPartitions = TRUE) #' } #' @note arrange(SparkDataFrame, Column) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "Column"), - function(x, col, ...) { + function(x, col, ..., withinPartitions = FALSE) { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", jcols) + if (withinPartitions) { + sdf <- callJMethod(x@sdf, "sortWithinPartitions", jcols) + } else { + sdf <- callJMethod(x@sdf, "sort", jcols) + } dataFrame(sdf) }) @@ -2332,7 +2338,7 @@ setMethod("arrange", #' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), - function(x, col, ..., decreasing = FALSE) { + function(x, col, ..., decreasing = FALSE, withinPartitions = FALSE) { # all sorting columns by <- list(col, ...) @@ -2356,7 +2362,7 @@ setMethod("arrange", } }) - do.call("arrange", c(x, jcols)) + do.call("arrange", c(x, jcols, withinPartitions = withinPartitions)) }) #' @rdname arrange @@ -3655,7 +3661,8 @@ setMethod("getNumPartitions", #' isStreaming #' #' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data -#' as it arrives. +#' as it arrives. A dataset that reads data from a streaming source must be executed as a +#' \code{StreamingQuery} using \code{write.stream}. #' #' @param x A SparkDataFrame #' @return TRUE if this SparkDataFrame is from a streaming source @@ -3701,7 +3708,17 @@ setMethod("isStreaming", #' @param df a streaming SparkDataFrame. 
#' @param source a name for external data source. #' @param outputMode one of 'append', 'complete', 'update'. -#' @param ... additional argument(s) passed to the method. +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar to Hive's +#' partitioning scheme. +#' @param trigger.processingTime a processing time interval as a string, e.g. '5 seconds', +#' '1 minute'. This is a trigger that runs a query periodically based on the processing +#' time. If value is '0 seconds', the query will run as fast as possible; this is the +#' default. Only one trigger can be set. +#' @param trigger.once a logical, must be set to \code{TRUE}. This is a trigger that processes only +#' one batch of data in a streaming query then terminates the query. Only one trigger can be +#' set. +#' @param ... additional external data source specific named options. #' #' @family SparkDataFrame functions #' @seealso \link{read.stream} @@ -3719,7 +3736,8 @@ setMethod("isStreaming", #' # console #' q <- write.stream(wordCounts, "console", outputMode = "complete") #' # text stream -#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp", +#' partitionBy = c("year", "month"), trigger.processingTime = "30 seconds") #' # memory stream #' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") #' head(sql("SELECT * from outs")) @@ -3731,7 +3749,8 @@ setMethod("isStreaming", #' @note experimental setMethod("write.stream", signature(df = "SparkDataFrame"), - function(df, source = NULL, outputMode = NULL, ...) { + function(df, source = NULL, outputMode = NULL, partitionBy = NULL, + trigger.processingTime = NULL, trigger.once = NULL, ...) { if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the data source specified ", "in 'spark.sql.sources.default' configuration by default.") @@ -3742,12 +3761,43 @@ setMethod("write.stream", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) { is.character(c) }))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } + jtrigger <- NULL + if (!is.null(trigger.processingTime) && !is.na(trigger.processingTime)) { + if (!is.null(trigger.once)) { + stop("Multiple triggers not allowed.") + } + interval <- as.character(trigger.processingTime) + if (nchar(interval) == 0) { + stop("Value for trigger.processingTime must be a non-empty string.") + } + jtrigger <- handledCallJStatic("org.apache.spark.sql.streaming.Trigger", + "ProcessingTime", + interval) + } else if (!is.null(trigger.once) && !is.na(trigger.once)) { + if (!is.logical(trigger.once) || !trigger.once) { + stop("Value for trigger.once must be TRUE.") + } + jtrigger <- callJStatic("org.apache.spark.sql.streaming.Trigger", "Once") + } options <- varargsToStrEnv(...)
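# Illustrative usage sketch (not part of the patch): the checks above accept at most one
# trigger, so a caller picks either a periodic trigger or a one-shot trigger. Assuming `df`
# is a streaming SparkDataFrame and the paths below are hypothetical:
#   q <- write.stream(df, "parquet", path = "/tmp/out", checkpointLocation = "/tmp/cp",
#                     partitionBy = "name", trigger.processingTime = "30 seconds")
#   q <- write.stream(df, "parquet", path = "/tmp/out", checkpointLocation = "/tmp/cp",
#                     trigger.once = TRUE)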
write <- handledCallJMethod(df@sdf, "writeStream") write <- callJMethod(write, "format", source) if (!is.null(outputMode)) { write <- callJMethod(write, "outputMode", outputMode) } + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } + if (!is.null(jtrigger)) { + write <- callJMethod(write, "trigger", jtrigger) + } write <- callJMethod(write, "options", options) ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) @@ -3782,6 +3832,33 @@ setMethod("checkpoint", dataFrame(df) }) +#' localCheckpoint +#' +#' Returns a locally checkpointed version of this SparkDataFrame. Checkpointing can be used to +#' truncate the logical plan, which is especially useful in iterative algorithms where the plan +#' may grow exponentially. Local checkpoints are stored in the executors using the caching +#' subsystem and therefore they are not reliable. +#' +#' @param x A SparkDataFrame +#' @param eager whether to locally checkpoint this SparkDataFrame immediately +#' @return a new locally checkpointed SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases localCheckpoint,SparkDataFrame-method +#' @rdname localCheckpoint +#' @name localCheckpoint +#' @export +#' @examples +#'\dontrun{ +#' df <- localCheckpoint(df) +#' } +#' @note localCheckpoint since 2.3.0 +setMethod("localCheckpoint", + signature(x = "SparkDataFrame"), + function(x, eager = TRUE) { + df <- callJMethod(x@sdf, "localCheckpoint", as.logical(eager)) + dataFrame(df) + }) + #' cube #' #' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. @@ -3934,3 +4011,47 @@ setMethod("broadcast", sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) dataFrame(sdf) }) + +#' withWatermark +#' +#' Defines an event time watermark for this streaming SparkDataFrame. A watermark tracks a point in +#' time before which we assume no more late data is going to arrive. +#' +#' Spark will use this watermark for several purposes: +#' \itemize{ +#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' when using output modes that do not allow updates. +#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' } +#' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across +#' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost +#' of coordinating this value across partitions, the actual watermark used is only guaranteed +#' to be at least \code{delayThreshold} behind the actual event time. In some cases we may still +#' process records that arrive more than \code{delayThreshold} late. +#' +#' @param x a streaming SparkDataFrame +#' @param eventTime a string specifying the name of the Column that contains the event time of the +#' row. +#' @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, +#' relative to the latest record that has been processed in the form of an +#' interval (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. +#' @return a SparkDataFrame. 
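# Illustrative sketch (not part of the patch): withWatermark is typically combined with a
# windowed aggregation so that old aggregation state can be dropped. The stream path and
# column names below are hypothetical.
events <- read.stream("json", path = "/tmp/events", schema = "eventTime TIMESTAMP, word STRING")
events <- withWatermark(events, "eventTime", "10 minutes")
counts <- count(groupBy(events, window(events$eventTime, "5 minutes"), events$word))
q <- write.stream(counts, "console", outputMode = "append")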
+#' @aliases withWatermark,SparkDataFrame,character,character-method +#' @family SparkDataFrame functions +#' @rdname withWatermark +#' @name withWatermark +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' schema <- structType(structField("time", "timestamp"), structField("value", "double")) +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' df <- withWatermark(df, "time", "10 minutes") +#' } +#' @note withWatermark since 2.3.0 +setMethod("withWatermark", + signature(x = "SparkDataFrame", eventTime = "character", delayThreshold = "character"), + function(x, eventTime, delayThreshold) { + sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold) + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3b7f71bbbffb..9d0a2d5e074e 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -727,7 +727,9 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it +#' uses the default value, session local timezone. #' @return SparkDataFrame #' @rdname read.stream #' @name read.stream diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 237ef061e807..55365a41d774 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -39,11 +39,19 @@ NULL #' Date time functions defined for \code{Column}. #' #' @param x Column to compute on. In \code{window}, it must be a time Column of -#' \code{TimestampType}. -#' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse -#' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string -#' to use to specify the truncation method. For example, "year", "yyyy", "yy" for -#' truncate by year, or "month", "mon", "mm" for truncate by month. +#' \code{TimestampType}. This is not used with \code{current_date} and +#' \code{current_timestamp} +#' @param format The format for the given dates or timestamps in Column \code{x}. See the +#' format used in the following methods: +#' \itemize{ +#' \item \code{to_date} and \code{to_timestamp}: it is the string to use to parse +#' Column \code{x} to DateType or TimestampType. +#' \item \code{trunc}: it is the string to use to specify the truncation method. +#' For example, "year", "yyyy", "yy" for truncate by year, or "month", "mon", +#' "mm" for truncate by month. +#' \item \code{date_trunc}: it is similar with \code{trunc}'s but additionally +#' supports "day", "dd", "second", "minute", "hour", "week" and "quarter". +#' } #' @param ... additional argument(s). #' @name column_datetime_functions #' @rdname column_datetime_functions @@ -1102,10 +1110,11 @@ setMethod("lower", }) #' @details -#' \code{ltrim}: Trims the spaces from left end for the specified string value. +#' \code{ltrim}: Trims the spaces from left end for the specified string value. Optionally a +#' \code{trimString} can be specified. 
#' #' @rdname column_string_functions -#' @aliases ltrim ltrim,Column-method +#' @aliases ltrim ltrim,Column,missing-method #' @export #' @examples #' @@ -1121,12 +1130,24 @@ setMethod("lower", #' head(tmp)} #' @note ltrim since 1.5.0 setMethod("ltrim", - signature(x = "Column"), - function(x) { + signature(x = "Column", trimString = "missing"), + function(x, trimString) { jc <- callJStatic("org.apache.spark.sql.functions", "ltrim", x@jc) column(jc) }) +#' @param trimString a character string to trim with +#' @rdname column_string_functions +#' @aliases ltrim,Column,character-method +#' @export +#' @note ltrim(Column, character) since 2.3.0 +setMethod("ltrim", + signature(x = "Column", trimString = "character"), + function(x, trimString) { + jc <- callJStatic("org.apache.spark.sql.functions", "ltrim", x@jc, trimString) + column(jc) + }) + #' @details #' \code{max}: Returns the maximum value of the expression in a group. #' @@ -1341,19 +1362,31 @@ setMethod("bround", }) #' @details -#' \code{rtrim}: Trims the spaces from right end for the specified string value. +#' \code{rtrim}: Trims the spaces from right end for the specified string value. Optionally a +#' \code{trimString} can be specified. #' #' @rdname column_string_functions -#' @aliases rtrim rtrim,Column-method +#' @aliases rtrim rtrim,Column,missing-method #' @export #' @note rtrim since 1.5.0 setMethod("rtrim", - signature(x = "Column"), - function(x) { + signature(x = "Column", trimString = "missing"), + function(x, trimString) { jc <- callJStatic("org.apache.spark.sql.functions", "rtrim", x@jc) column(jc) }) +#' @rdname column_string_functions +#' @aliases rtrim,Column,character-method +#' @export +#' @note rtrim(Column, character) since 2.3.0 +setMethod("rtrim", + signature(x = "Column", trimString = "character"), + function(x, trimString) { + jc <- callJStatic("org.apache.spark.sql.functions", "rtrim", x@jc, trimString) + column(jc) + }) + #' @details #' \code{sd}: Alias for \code{stddev_samp}. #' @@ -1782,19 +1815,31 @@ setMethod("to_timestamp", }) #' @details -#' \code{trim}: Trims the spaces from both ends for the specified string column. +#' \code{trim}: Trims the spaces from both ends for the specified string column. Optionally a +#' \code{trimString} can be specified. #' #' @rdname column_string_functions -#' @aliases trim trim,Column-method +#' @aliases trim trim,Column,missing-method #' @export #' @note trim since 1.5.0 setMethod("trim", - signature(x = "Column"), - function(x) { + signature(x = "Column", trimString = "missing"), + function(x, trimString) { jc <- callJStatic("org.apache.spark.sql.functions", "trim", x@jc) column(jc) }) +#' @rdname column_string_functions +#' @aliases trim,Column,character-method +#' @export +#' @note trim(Column, character) since 2.3.0 +setMethod("trim", + signature(x = "Column", trimString = "character"), + function(x, trimString) { + jc <- callJStatic("org.apache.spark.sql.functions", "trim", x@jc, trimString) + column(jc) + }) + #' @details #' \code{unbase64}: Decodes a BASE64 encoded string column and returns it as a binary column. #' This is the reverse of base64. @@ -2088,7 +2133,8 @@ setMethod("countDistinct", }) #' @details -#' \code{concat}: Concatenates multiple input string columns together into a single string column. +#' \code{concat}: Concatenates multiple input columns together into a single column. +#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. 
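# Illustrative sketch (not part of the patch): the new optional trimString argument to
# ltrim/rtrim/trim, assuming `df` has a string column `value` such as "xxSparkxx".
head(select(df, ltrim(df$value, "x"), rtrim(df$value, "x"), trim(df$value, "x")))
# The one-argument forms are unchanged and still trim whitespace only.
head(select(df, trim(df$value)))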
#' #' @rdname column_string_functions #' @aliases concat concat,Column-method @@ -2770,11 +2816,11 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), }) #' @details -#' \code{substring_index}: Returns the substring from string str before count occurrences of -#' the delimiter delim. If count is positive, everything the left of the final delimiter -#' (counting from left) is returned. If count is negative, every to the right of the final -#' delimiter (counting from the right) is returned. substring_index performs a case-sensitive -#' match when searching for delim. +#' \code{substring_index}: Returns the substring from string (\code{x}) before \code{count} +#' occurrences of the delimiter (\code{delim}). If \code{count} is positive, everything to the left of +#' the final delimiter (counting from left) is returned. If \code{count} is negative, everything to the +#' right of the final delimiter (counting from the right) is returned. \code{substring_index} +#' performs a case-sensitive match when searching for the delimiter. #' #' @param delim a delimiter string. #' @param count number of occurrences of \code{delim} before the substring is returned. @@ -3478,3 +3524,53 @@ setMethod("trunc", x@jc, as.character(format)) column(jc) }) + +#' @details +#' \code{date_trunc}: Returns timestamp truncated to the unit specified by the format. +#' +#' @rdname column_datetime_functions +#' @aliases date_trunc date_trunc,character,Column-method +#' @export +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, date_trunc("hour", df$time), date_trunc("minute", df$time), +#' date_trunc("week", df$time), date_trunc("quarter", df$time)))} +#' @note date_trunc since 2.3.0 +setMethod("date_trunc", + signature(format = "character", x = "Column"), + function(format, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_trunc", format, x@jc) + column(jc) + }) + +#' @details +#' \code{current_date}: Returns the current date as a date column. +#' +#' @rdname column_datetime_functions +#' @aliases current_date current_date,missing-method +#' @export +#' @examples +#' \dontrun{ +#' head(select(df, current_date(), current_timestamp()))} +#' @note current_date since 2.3.0 +setMethod("current_date", + signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "current_date") + column(jc) + }) + +#' @details +#' \code{current_timestamp}: Returns the current timestamp as a timestamp column.
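# Illustrative sketch (not part of the patch): the new date/time helpers, assuming `df`
# has a timestamp column `time`.
head(select(df, current_date(), current_timestamp(), date_trunc("week", df$time)))
# date_trunc() accepts the trunc() formats plus finer units such as "hour" or "minute".
head(select(df, date_trunc("hour", df$time), date_trunc("minute", df$time)))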
+#' +#' @rdname column_datetime_functions +#' @aliases current_timestamp current_timestamp,missing-method +#' @export +#' @note current_timestamp since 2.3.0 +setMethod("current_timestamp", + signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "current_timestamp") + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8fcf269087c7..e0dde3339fab 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -611,6 +611,10 @@ setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) +#' @rdname localCheckpoint +#' @export +setGeneric("localCheckpoint", function(x, eager = TRUE) { standardGeneric("localCheckpoint") }) + #' @rdname merge #' @export setGeneric("merge") @@ -795,6 +799,12 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname withWatermark +#' @export +setGeneric("withWatermark", function(x, eventTime, delayThreshold) { + standardGeneric("withWatermark") +}) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) @@ -1023,6 +1033,17 @@ setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) +#' @rdname column_datetime_functions +#' @export +#' @name NULL +setGeneric("current_date", function(x = "missing") { standardGeneric("current_date") }) + +#' @rdname column_datetime_functions +#' @export +#' @name NULL +setGeneric("current_timestamp", function(x = "missing") { standardGeneric("current_timestamp") }) + + #' @rdname column_datetime_diff_functions #' @export #' @name NULL @@ -1043,6 +1064,11 @@ setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) #' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) +#' @rdname column_datetime_functions +#' @export +#' @name NULL +setGeneric("date_trunc", function(format, x) { standardGeneric("date_trunc") }) + #' @rdname column_datetime_functions #' @export #' @name NULL @@ -1221,7 +1247,7 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @rdname column_string_functions #' @export #' @name NULL -setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) +setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @rdname column_collection_functions #' @export @@ -1371,7 +1397,7 @@ setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @rdname column_string_functions #' @export #' @name NULL -setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) +setGeneric("rtrim", function(x, trimString) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions #' @export @@ -1511,7 +1537,7 @@ setGeneric("translate", function(x, matchingString, replaceString) { standardGen #' @rdname column_string_functions #' @export #' @name NULL -setGeneric("trim", function(x) { standardGeneric("trim") }) +setGeneric("trim", function(x, trimString) { standardGeneric("trim") }) #' @rdname column_string_functions #' @export diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index fb5f1d21fc72..965471f3b07a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -560,10 +560,55 @@ cancelJobGroup <- function(sc, groupId) { #'} #' @note setJobDescription since 2.3.0 setJobDescription <- 
function(value) { + if (!is.null(value)) { + value <- as.character(value) + } sc <- getSparkContext() invisible(callJMethod(sc, "setJobDescription", value)) } +#' Set a local property that affects jobs submitted from this thread, such as the +#' Spark fair scheduler pool. +#' +#' @param key The key for a local property. +#' @param value The value for a local property. +#' @rdname setLocalProperty +#' @name setLocalProperty +#' @examples +#'\dontrun{ +#' setLocalProperty("spark.scheduler.pool", "poolA") +#'} +#' @note setLocalProperty since 2.3.0 +setLocalProperty <- function(key, value) { + if (is.null(key) || is.na(key)) { + stop("key should not be NULL or NA.") + } + if (!is.null(value)) { + value <- as.character(value) + } + sc <- getSparkContext() + invisible(callJMethod(sc, "setLocalProperty", as.character(key), value)) +} + +#' Get a local property set in this thread, or \code{NULL} if it is missing. See +#' \code{setLocalProperty}. +#' +#' @param key The key for a local property. +#' @rdname getLocalProperty +#' @name getLocalProperty +#' @examples +#'\dontrun{ +#' getLocalProperty("spark.scheduler.pool") +#'} +#' @note getLocalProperty since 2.3.0 +getLocalProperty <- function(key) { + if (is.null(key) || is.na(key)) { + stop("key should not be NULL or NA.") + } + sc <- getSparkContext() + callJMethod(sc, "getLocalProperty", as.character(key)) +} + sparkConfToSubmitOps <- new.env() sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" diff --git a/R/pkg/tests/fulltests/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R index b2ec6c67311d..209827d9fdc2 100644 --- a/R/pkg/tests/fulltests/test_Windows.R +++ b/R/pkg/tests/fulltests/test_Windows.R @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
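# Illustrative sketch (not part of the patch): the setLocalProperty/getLocalProperty helpers
# defined above, e.g. routing jobs submitted from this thread to a fair-scheduler pool
# (the pool name is hypothetical).
sparkR.session()
setLocalProperty("spark.scheduler.pool", "poolA")
getLocalProperty("spark.scheduler.pool")        # returns "poolA"
setLocalProperty("spark.scheduler.pool", NULL)  # clears the property again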
# + context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index 77635c5a256b..f0d0a5114f89 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -100,7 +100,6 @@ test_that("job group functions can be called", { setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() - setJobDescription("job description") suppressWarnings(setJobGroup(sc, "groupId", "job description", TRUE)) suppressWarnings(cancelJobGroup(sc, "groupId")) @@ -108,6 +107,38 @@ test_that("job group functions can be called", { sparkR.session.stop() }) +test_that("job description and local properties can be set and got", { + sc <- sparkR.sparkContext(master = sparkRTestMaster) + setJobDescription("job description") + expect_equal(getLocalProperty("spark.job.description"), "job description") + setJobDescription(1234) + expect_equal(getLocalProperty("spark.job.description"), "1234") + setJobDescription(NULL) + expect_equal(getLocalProperty("spark.job.description"), NULL) + setJobDescription(NA) + expect_equal(getLocalProperty("spark.job.description"), NULL) + + setLocalProperty("spark.scheduler.pool", "poolA") + expect_equal(getLocalProperty("spark.scheduler.pool"), "poolA") + setLocalProperty("spark.scheduler.pool", NULL) + expect_equal(getLocalProperty("spark.scheduler.pool"), NULL) + setLocalProperty("spark.scheduler.pool", NA) + expect_equal(getLocalProperty("spark.scheduler.pool"), NULL) + + setLocalProperty(4321, 1234) + expect_equal(getLocalProperty(4321), "1234") + setLocalProperty(4321, NULL) + expect_equal(getLocalProperty(4321), NULL) + setLocalProperty(4321, NA) + expect_equal(getLocalProperty(4321), NULL) + + expect_error(setLocalProperty(NULL, "should fail"), "key should not be NULL or NA") + expect_error(getLocalProperty(NULL), "key should not be NULL or NA") + expect_error(setLocalProperty(NA, "should fail"), "key should not be NULL or NA") + expect_error(getLocalProperty(NA), "key should not be NULL or NA") + sparkR.session.stop() +}) + test_that("utility function can be called", { sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index d87f5d270573..5197838eaac6 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -957,6 +957,28 @@ test_that("setCheckpointDir(), checkpoint() on a DataFrame", { } }) +test_that("localCheckpoint() on a DataFrame", { + if (windows_with_hadoop()) { + # Checkpoint directory shouldn't matter in localCheckpoint. + checkpointDir <- file.path(tempdir(), "lcproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE, recursive = TRUE)) == 0) + setCheckpointDir(checkpointDir) + + textPath <- tempfile(pattern = "textPath", fileext = ".txt") + writeLines(mockLines, textPath) + # Read it lazily and then locally checkpoint eagerly. + df <- read.df(textPath, "text") + df <- localCheckpoint(df, eager = TRUE) + # Here, we remove the source path to check eagerness. 
+ unlink(textPath) + expect_is(df, "SparkDataFrame") + expect_equal(colnames(df), c("value")) + expect_equal(count(df), 3) + + expect_true(length(list.files(path = checkpointDir, all.files = TRUE, recursive = TRUE)) == 0) + } +}) + test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- read.json(jsonPath) testSchema <- schema(df) @@ -1405,7 +1427,7 @@ test_that("column functions", { c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c) c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) - c12 <- variance(c) + c12 <- variance(c) + ltrim(c, "a") + rtrim(c, "b") + trim(c, "c") c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() @@ -1418,6 +1440,8 @@ test_that("column functions", { c22 <- not(c) c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") + c24 <- date_trunc("hour", c) + date_trunc("minute", c) + date_trunc("week", c) + + date_trunc("quarter", c) + current_date() + current_timestamp() # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1729,6 +1753,7 @@ test_that("date functions on a DataFrame", { expect_gt(collect(select(df2, unix_timestamp()))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + expect_equal(collect(select(df2, month(date_trunc("yyyy", df2$b))))[, 1], c(1, 1)) l3 <- list(list(a = 1000), list(a = -1000)) df3 <- createDataFrame(l3) @@ -2105,6 +2130,11 @@ test_that("arrange() and orderBy() on a DataFrame", { sorted7 <- arrange(df, "name", decreasing = FALSE) expect_equal(collect(sorted7)[2, "age"], 19) + + df <- createDataFrame(cars, numPartitions = 10) + expect_equal(getNumPartitions(df), 10) + sorted8 <- arrange(df, "dist", withinPartitions = TRUE) + expect_equal(collect(sorted8)[5:6, "dist"], c(22, 10)) }) test_that("filter() on a DataFrame", { diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 54f40bbd5f51..a354d50c6b54 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -172,6 +172,113 @@ test_that("Terminated by error", { stopQuery(q) }) +test_that("PartitionBy", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + checkpointPath <- tempfile(pattern = "sparkr-test", fileext = ".checkpoint") + textPath <- tempfile(pattern = "sparkr-test", fileext = ".text") + df <- read.df(jsonPath, "json", stringSchema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + + expect_error(write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = c(1, 2)), + "All partitionBy column names should be characters") + + q <- write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = "name") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + dirs <- list.files(textPath) + expect_equal(length(dirs[substring(dirs, 1, nchar("name=")) == "name="]), 3) + + unlink(checkpointPath) + unlink(textPath) + unlink(parquetPath) +}) + +test_that("Watermark", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- 
structType(structField("value", "string")) + t <- Sys.time() + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + df <- withColumn(df, "eventTime", cast(df$value, "timestamp")) + df <- withWatermark(df, "eventTime", "10 seconds") + counts <- count(group_by(df, "eventTime")) + q <- write.stream(counts, "memory", queryName = "times", outputMode = "append") + + # first events + df <- as.DataFrame(lapply(list(t + 1, t, t + 2), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # advance watermark to 15 + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # old events, should be dropped + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # evict events less than previous watermark + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + times <- collect(sql("SELECT * FROM times")) + # looks like write timing can affect the first bucket; but it should be t + expect_equal(times[order(times$eventTime),][1, 2], 2) + + stopQuery(q) + unlink(parquetPath) +}) + +test_that("Trigger", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "", trigger.once = ""), "Multiple triggers not allowed.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = ""), + "Value for trigger.processingTime must be a non-empty string.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "invalid"), "illegal argument") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = ""), "Value for trigger.once must be TRUE.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = FALSE), "Value for trigger.once must be TRUE.") + + q <- write.stream(df, "memory", queryName = "times", outputMode = "append", trigger.once = TRUE) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(nrow(collect(sql("SELECT * FROM times"))), 1) + + stopQuery(q) + unlink(parquetPath) +}) + unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 63812ba70bb5..94d75188fb94 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -27,7 +27,10 @@ if (.Platform$OS.type == "windows") { # Setup global test environment # Install Spark first to set SPARK_HOME 
-install.spark() + +# NOTE(shivaram): We set overwrite to handle any old tar.gz files or directories left behind on +# CRAN machines. For Jenkins we should already have SPARK_HOME set. +install.spark(overwrite = TRUE) sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 8c4ea2f2db18..2e662424b25f 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -391,8 +391,7 @@ We convert `mpg` to `kmpg` (kilometers per gallon). `carsSubDF` is a `SparkDataF ```{r} carsSubDF <- select(carsDF, "model", "mpg") -schema <- structType(structField("model", "string"), structField("mpg", "double"), - structField("kmpg", "double")) +schema <- "model STRING, mpg DOUBLE, kmpg DOUBLE" out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` diff --git a/bin/find-spark-home b/bin/find-spark-home index fa78407d4175..617dbaa4fff8 100755 --- a/bin/find-spark-home +++ b/bin/find-spark-home @@ -21,7 +21,7 @@ FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" -# Short cirtuit if the user already has this set. +# Short circuit if the user already has this set. if [ ! -z "${SPARK_HOME}" ]; then exit 0 elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index b3ba76ba5805..f62e85d43531 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -86,7 +86,7 @@ class LevelDBIterator implements KVStoreIterator { end = index.start(parent, params.last); } if (it.hasNext()) { - // When descending, the caller may have set up the start of iteration at a non-existant + // When descending, the caller may have set up the start of iteration at a non-existent // entry that is guaranteed to be after the desired entry. 
For example, if you have a // compound key (a, b) where b is a, integer, you may seek to the end of the elements that // have the same "a" value by specifying Integer.MAX_VALUE for "b", and that value may not diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index ea9b3ce4e352..8b8f9892847c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -18,12 +18,12 @@ package org.apache.spark.network.buffer; import java.io.File; +import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; -import java.nio.file.Files; import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; @@ -94,9 +94,9 @@ public ByteBuffer nioByteBuffer() throws IOException { @Override public InputStream createInputStream() throws IOException { - InputStream is = null; + FileInputStream is = null; try { - is = Files.newInputStream(file.toPath()); + is = new FileInputStream(file); ByteStreams.skipFully(is, offset); return new LimitedInputStream(is, length); } catch (IOException e) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index 897d0f9e4fb8..a5337656cbd8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -47,7 +47,7 @@ class MessageWithHeader extends AbstractFileRegion { /** * When the write buffer size is larger than this limit, I/O will be done in chunks of this size. * The size should not be too large as it will waste underlying memory copy. e.g. If network - * avaliable buffer is smaller than this limit, the data cannot be sent within one single write + * available buffer is smaller than this limit, the data cannot be sent within one single write * operation while it still will make memory copy with this size. */ private static final int NIO_BUFFER_LIMIT = 256 * 1024; @@ -100,7 +100,7 @@ public long transferred() { * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. * * The contract is that the caller will ensure position is properly set to the total number - * of bytes transferred so far (i.e. value returned by transfered()). + * of bytes transferred so far (i.e. value returned by transferred()). */ @Override public long transferTo(final WritableByteChannel target, final long position) throws IOException { diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 16ab4efcd4f5..3ac9081d78a7 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -38,7 +38,7 @@ import org.apache.spark.network.util.NettyUtils; /** - * Provides SASL-based encription for transport channels. 
The single method exposed by this + * Provides SASL-based encryption for transport channels. The single method exposed by this * class installs the needed channel handlers on a connected channel. */ class SaslEncryption { @@ -166,7 +166,7 @@ static class EncryptedMessage extends AbstractFileRegion { * This makes assumptions about how netty treats FileRegion instances, because there's no way * to know beforehand what will be the size of the encrypted message. Namely, it assumes * that netty will try to transfer data from this message while - * transfered() < count(). So these two methods return, technically, wrong data, + * transferred() < count(). So these two methods return, technically, wrong data, * but netty doesn't know better. */ @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 50d9651ccbbb..8e73ab077a5c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -29,7 +29,7 @@ /** * A customized frame decoder that allows intercepting raw data. *
- * This behaves like Netty's frame decoder (with harcoded parameters that match this library's + * This behaves like Netty's frame decoder (with hard coded parameters that match this library's * needs), except it allows an interceptor to be installed to read data directly before it's * framed. *
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 3f2f20b4149f..9cac7d00cc6b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -18,11 +18,11 @@ package org.apache.spark.network.shuffle; import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; -import java.nio.file.Files; import java.util.Arrays; import org.slf4j.Logger; @@ -165,7 +165,7 @@ private class DownloadCallback implements StreamCallback { DownloadCallback(int chunkIndex) throws IOException { this.targetFile = tempFileManager.createTempFile(); - this.channel = Channels.newChannel(Files.newOutputStream(targetFile.toPath())); + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); this.chunkIndex = chunkIndex; } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 23438a08fa09..6d201b8fe8d7 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -127,7 +127,7 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); assertEquals(parsedShuffleInfo, shuffleInfo); - // Intentionally keep these hard-coded strings in here, to check backwards-compatability. + // Intentionally keep these hard-coded strings in here, to check backwards-compatibility. // its not legacy yet, but keeping this here in case anybody changes it String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index c0b425e72959..37803c7a3b10 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -34,7 +34,7 @@ *
 *   <li>{@link String}</li>
 * </ul>
  • * * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that - * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that hasu + * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that has * not actually been put in the {@code BloomFilter}. * * The implementation is largely based on the {@code BloomFilter} class from Guava. diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index f121b1cd745b..a6b1f7a16d60 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -66,7 +66,7 @@ public static boolean arrayEquals( i += 1; } } - // for architectures that suport unaligned accesses, chew it up 8 bytes at a time + // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (unaligned || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) { while (i <= length - 8) { if (Platform.getLong(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 7ced13d35723..c03caf0076f6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -74,4 +74,29 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { } return Arrays.copyOfRange(bytes, start, end); } + + public static byte[] concat(byte[]... inputs) { + // Compute the total length of the result + int totalLength = 0; + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + totalLength += inputs[i].length; + } else { + return null; + } + } + + // Allocate a new byte array, and copy the inputs one by one into it + final byte[] result = new byte[totalLength]; + int offset = 0; + for (int i = 0; i < inputs.length; i++) { + int len = inputs[i].length; + Platform.copyMemory( + inputs[i], Platform.BYTE_ARRAY_OFFSET, + result, Platform.BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + } + return result; + } } diff --git a/core/pom.xml b/core/pom.xml index fa138d3e7a4e..0a5bd958fc9c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -351,6 +351,14 @@ spark-tags_${scala.binary.version} + + org.apache.spark + spark-launcher_${scala.binary.version} + ${project.version} + tests + test + + 0.10.2 - 4.5.2 - 4.4.4 + 4.5.4 + 4.4.8 3.1 3.4.1 @@ -2295,6 +2295,9 @@ org.apache.maven.plugins maven-assembly-plugin 3.1.0 + + posix + org.apache.maven.plugins diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813e..3b452f35c5ec 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), diff --git a/python/README.md b/python/README.md index 84ec88141cb0..3f17fdb98a08 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.6), but additional sub-packages have their own requirements (including numpy and pandas). +At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index a6767cee9bf2..d4470b5bf290 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -18,13 +18,52 @@ from abc import ABCMeta, abstractmethod import copy +import threading from pyspark import since -from pyspark.ml.param import Params from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc from pyspark.sql.functions import udf -from pyspark.sql.types import StructField, StructType, DoubleType +from pyspark.sql.types import StructField, StructType + + +class _FitMultipleIterator(object): + """ + Used by default implementation of Estimator.fitMultiple to produce models in a thread safe + iterator. This class handles the simple case of fitMultiple where each param map should be + fit independently. + + :param fitSingleModel: Function: (int => Model) which fits an estimator to a dataset. + `fitSingleModel` may be called up to `numModels` times, with a unique index each time. + Each call to `fitSingleModel` with an index should return the Model associated with + that index. + :param numModel: Number of models this iterator should produce. + + See Estimator.fitMultiple for more info. + """ + def __init__(self, fitSingleModel, numModels): + """ + + """ + self.fitSingleModel = fitSingleModel + self.numModel = numModels + self.counter = 0 + self.lock = threading.Lock() + + def __iter__(self): + return self + + def __next__(self): + with self.lock: + index = self.counter + if index >= self.numModel: + raise StopIteration("No models remaining.") + self.counter += 1 + return index, self.fitSingleModel(index) + + def next(self): + """For python2 compatibility.""" + return self.__next__() @inherit_doc @@ -47,6 +86,27 @@ def _fit(self, dataset): """ raise NotImplementedError() + @since("2.3.0") + def fitMultiple(self, dataset, paramMaps): + """ + Fits a model to the input dataset for each param map in `paramMaps`. + + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`. + :param paramMaps: A Sequence of param maps. + :return: A thread safe iterable which contains one model for each param map. Each + call to `next(modelIterator)` will return `(index, model)` where model was fit + using `paramMaps[index]`. `index` values may not be sequential. + + .. note:: DeveloperApi + .. 
note:: Experimental + """ + estimator = self.copy() + + def fitSingleModel(index): + return estimator.fit(dataset, paramMaps[index]) + + return _FitMultipleIterator(fitSingleModel, len(paramMaps)) + @since("1.3.0") def fit(self, dataset, params=None): """ @@ -61,7 +121,10 @@ def fit(self, dataset, params=None): if params is None: params = dict() if isinstance(params, (list, tuple)): - return [self.fit(dataset, paramMap) for paramMap in params] + models = [None] * len(params) + for index, model in self.fitMultiple(dataset, params): + models[index] = model + return models elif isinstance(params, dict): if params: return self.copy(params)._fit(dataset) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 608f2a571549..13bf95cce40b 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -57,6 +57,7 @@ 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorIndexerModel', + 'VectorSizeHint', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] @@ -713,9 +714,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, * Numeric columns: For numeric features, the hash value of the column name is used to map the - feature value to its index in the feature vector. Numeric features are never - treated as categorical, even when they are integers. You must explicitly - convert numeric columns containing categorical features to strings first. + feature value to its index in the feature vector. By default, numeric features + are not treated as categorical (even when they are integers). To treat them + as categorical, specify the relevant columns in `categoricalCols`. * String columns: For categorical features, the hash value of the string "column_name=value" @@ -740,6 +741,8 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") >>> hasher.transform(df).head().features SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + >>> hasher.setCategoricalCols(["real"]).transform(df).head().features + SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0}) >>> hasherPath = temp_path + "/hasher" >>> hasher.save(hasherPath) >>> loadedHasher = FeatureHasher.load(hasherPath) @@ -751,10 +754,14 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, .. 
versionadded:: 2.3.0 """ + categoricalCols = Param(Params._dummy(), "categoricalCols", + "numeric columns to treat as categorical", + typeConverter=TypeConverters.toListString) + @keyword_only - def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None): + def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None): """ - __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None) + __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None) """ super(FeatureHasher, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FeatureHasher", self.uid) @@ -764,14 +771,28 @@ def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None): @keyword_only @since("2.3.0") - def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None): + def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None): """ - setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None) + setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None) Sets params for this FeatureHasher. """ kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.3.0") + def setCategoricalCols(self, value): + """ + Sets the value of :py:attr:`categoricalCols`. + """ + return self._set(categoricalCols=value) + + @since("2.3.0") + def getCategoricalCols(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.categoricalCols) + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, @@ -3466,6 +3487,84 @@ def selectedFeatures(self): return self._call_java("selectedFeatures") +@inherit_doc +class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): + """ + .. note:: Experimental + + A feature transformer that adds size information to the metadata of a vector column. + VectorAssembler needs size information for its input columns and cannot be used on streaming + dataframes without this metadata. + + .. note:: VectorSizeHint modifies `inputCol` to include size metadata and does not have an + outputCol. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml import Pipeline, PipelineModel + >>> data = [(Vectors.dense([1., 2., 3.]), 4.)] + >>> df = spark.createDataFrame(data, ["vector", "float"]) + >>> + >>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip") + >>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled") + >>> pipeline = Pipeline(stages=[sizeHint, vecAssembler]) + >>> + >>> pipelineModel = pipeline.fit(df) + >>> pipelineModel.transform(df).head().assembled + DenseVector([1.0, 2.0, 3.0, 4.0]) + >>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline" + >>> pipelineModel.save(vectorSizeHintPath) + >>> loadedPipeline = PipelineModel.load(vectorSizeHintPath) + >>> loaded = loadedPipeline.transform(df).head().assembled + >>> expected = pipelineModel.transform(df).head().assembled + >>> loaded == expected + True + + .. versionadded:: 2.3.0 + """ + + size = Param(Params._dummy(), "size", "Size of vectors in column.", + typeConverter=TypeConverters.toInt) + + handleInvalid = Param(Params._dummy(), "handleInvalid", + "How to handle invalid vectors in inputCol. Invalid vectors include " + "nulls and vectors with the wrong size. 
The options are `skip` (filter "
+ "out rows with invalid vectors), `error` (throw an error) and "
+ "`optimistic` (do not check the vector size, and keep all rows). "
+ "`error` by default.",
+ TypeConverters.toString)
+
+ @keyword_only
+ def __init__(self, inputCol=None, size=None, handleInvalid="error"):
+ """
+ __init__(self, inputCol=None, size=None, handleInvalid="error")
+ """
+ super(VectorSizeHint, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid)
+ self._setDefault(handleInvalid="error")
+ self.setParams(**self._input_kwargs)
+
+ @keyword_only
+ @since("2.3.0")
+ def setParams(self, inputCol=None, size=None, handleInvalid="error"):
+ """
+ setParams(self, inputCol=None, size=None, handleInvalid="error")
+ Sets params for this VectorSizeHint.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.3.0")
+ def getSize(self):
+ """ Gets size param, the size of vectors in `inputCol`."""
+ return self.getOrDefault(self.size)
+
+ @since("2.3.0")
+ def setSize(self, value):
+ """ Sets size param, the size of vectors in `inputCol`."""
+ return self._set(size=value)
+
+
 if __name__ == "__main__":
 import doctest
 import tempfile
diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py
index 384599dc0c53..c9b840276f67 100644
--- a/python/pyspark/ml/image.py
+++ b/python/pyspark/ml/image.py
@@ -212,7 +212,7 @@ def readImages(self, path, recursive=False, numPartitions=-1,
 ImageSchema = _ImageSchema()
-# Monkey patch to disallow instantization of this class.
+# Monkey patch to disallow instantiation of this class.
 def _disallow_instance(_):
 raise RuntimeError("Creating instance of _ImageSchema class is disallowed.")
 _ImageSchema.__init__ = _disallow_instance
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index afcb0881c4dc..1af2b91da900 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -2380,6 +2380,21 @@ def test_unary_transformer_transform(self):
 self.assertEqual(res.input + shiftVal, res.output)
+
+class EstimatorTest(unittest.TestCase):
+
+ def testDefaultFitMultiple(self):
+ N = 4
+ data = MockDataset()
+ estimator = MockEstimator()
+ params = [{estimator.fake: i} for i in range(N)]
+ modelIter = estimator.fitMultiple(data, params)
+ indexList = []
+ for index, model in modelIter:
+ self.assertEqual(model.getFake(), index)
+ indexList.append(index)
+ self.assertEqual(sorted(indexList), list(range(N)))
+
+
 if __name__ == "__main__":
 from pyspark.ml.tests import *
 if xmlrunner:
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 47351133524e..6c0cad6cbaaa 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -31,6 +31,28 @@
 'TrainValidationSplitModel']
+
+def _parallelFitTasks(est, train, eva, validation, epm):
+ """
+ Creates a list of callables which can be called from different threads to fit and evaluate
+ an estimator in parallel. Each callable returns an `(index, metric)` pair.
+
+ :param est: Estimator, the estimator to be fit.
+ :param train: DataFrame, training data set, used for fitting.
+ :param eva: Evaluator, used to compute `metric`.
+ :param validation: DataFrame, validation data set, used for evaluation.
+ :param epm: Sequence of ParamMap, param maps to be used during fitting and evaluation.
+ :return: a list of callables, each returning an `(index, metric)` pair, where `index` is an index into `epm` and `metric` is the associated metric value. 
+ """ + modelIter = est.fitMultiple(train, epm) + + def singleTask(): + index, model = next(modelIter) + metric = eva.evaluate(model.transform(validation, epm[index])) + return index, metric + + return [singleTask] * len(epm) + + class ParamGridBuilder(object): r""" Builder for a param grid used in grid search-based model selection. @@ -266,15 +288,9 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - def singleTrain(paramMap): - model = est.fit(train, paramMap) - # TODO: duplicate evaluator to take extra params from input - metric = eva.evaluate(model.transform(validation, paramMap)) - return metric - - currentFoldMetrics = pool.map(singleTrain, epm) - for j in range(numModels): - metrics[j] += (currentFoldMetrics[j] / nFolds) + tasks = _parallelFitTasks(est, train, eva, validation, epm) + for j, metric in pool.imap_unordered(lambda f: f(), tasks): + metrics[j] += (metric / nFolds) validation.unpersist() train.unpersist() @@ -523,13 +539,11 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - def singleTrain(paramMap): - model = est.fit(train, paramMap) - metric = eva.evaluate(model.transform(validation, paramMap)) - return metric - + tasks = _parallelFitTasks(est, train, eva, validation, epm) pool = ThreadPool(processes=min(self.getParallelism(), numModels)) - metrics = pool.map(singleTrain, epm) + metrics = [None] * numModels + for j, metric in pool.imap_unordered(lambda f: f(), tasks): + metrics[j] = metric train.unpersist() validation.unpersist() diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 659bc65701a0..156603128d06 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. 
:param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) + + # This is to check whether the input function is a wrapped/native UserDefinedFunction + if hasattr(f, 'asNondeterministic'): + udf = UserDefinedFunction(f.func, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=f.deterministic) + else: + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b1e723cdecef..b8d86cc098e9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. 
:param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ddd8df3b15bf..733e32bd825b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1374,7 +1374,8 @@ def hash(*cols): @ignore_unicode_prefix def concat(*cols): """ - Concatenates multiple input string columns together into a single string column. + Concatenates multiple input columns together into a single column. + If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat(df.s, df.d).alias('s')).collect() @@ -2093,9 +2094,14 @@ class PandasUDFType(object): def udf(f=None, returnType=StringType()): """Creates a user defined function (UDF). - .. note:: The user-defined functions must be deterministic. Due to optimization, - duplicate invocations may be eliminated or the function may even be invoked more times than - it is present in the query. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> from pyspark.sql.types import IntegerType + >>> import random + >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. If the functions @@ -2208,7 +2214,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - .. note:: The user-defined function must be deterministic. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP + ... def random(v): + ... import numpy as np + ... import pandas as pd + ... return pd.Series(np.random.randn(len(v)) + >>> random = random.asNondeterministic() # doctest: +SKIP .. 
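A fuller, runnable sketch of the nondeterministic `pandas_udf` pattern described in the note above, assuming pandas and numpy are installed and `df` is an existing DataFrame with an `id` column (names here are illustrative):

    import numpy as np
    import pandas as pd
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf('double', PandasUDFType.SCALAR)
    def random_noise(v):
        # One random value per input row.
        return pd.Series(np.random.randn(len(v)))

    # Mark it nondeterministic so the optimizer neither collapses nor duplicates its invocations.
    random_noise = random_noise.asNondeterministic()
    df.select(random_noise(df['id'])).show()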
note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. If the functions diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 4e58bfb84364..49af1bcee5ef 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -333,7 +333,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -344,17 +344,17 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non or RDD of Strings storing CSV rows. :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). - :param sep: sets the single character as a separator for each field and value. + :param sep: sets a single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, it uses the default value, ``UTF-8``. - :param quote: sets the single character used for escaping quoted values where the + :param quote: sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ``"``. If you would like to turn off quotations, you need to set an empty string. - :param escape: sets the single character used for escaping quotes inside an already + :param escape: sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, ``\``. - :param comment: sets the single character used for skipping lines beginning with this + :param comment: sets a single character used for skipping lines beginning with this character. By default (None), it is disabled. :param header: uses the first line as names of columns. If None is set, it uses the default value, ``false``. @@ -410,6 +410,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. :param multiLine: parse records, which may span multiple lines. If None is set, it uses the default value, ``false``. + :param charToEscapeQuoteEscaping: sets a single character used for escaping the escape for + the quote character. If None is set, the default value is + escape character when escape and quote characters are + different, ``\0`` otherwise. 
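To make `charToEscapeQuoteEscaping` concrete, a hedged sketch of a read that sets it explicitly alongside `quote` and `escape` (the path is a placeholder; when `quote` and `escape` differ, the value shown here matches the documented default):

    df = spark.read.csv("/tmp/input.csv",
                        header=True,
                        quote='"',                       # character wrapping quoted fields
                        escape='\\',                     # escapes quotes inside quoted values
                        charToEscapeQuoteEscaping='\\')  # escapes the escape character itself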
>>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -427,7 +431,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, - columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -814,7 +819,8 @@ def text(self, path, compression=None): @since(2.0) def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, - timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None): + timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, + charToEscapeQuoteEscaping=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -829,12 +835,12 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). - :param sep: sets the single character as a separator for each field and value. If None is + :param sep: sets a single character as a separator for each field and value. If None is set, it uses the default value, ``,``. - :param quote: sets the single character used for escaping quoted values where the + :param quote: sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ``"``. If an empty string is set, it uses ``u0000`` (null character). - :param escape: sets the single character used for escaping quotes inside an already + :param escape: sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, ``\`` :param escapeQuotes: a flag indicating whether values containing quotes should always be enclosed in quotes. If None is set, it uses the default value @@ -860,6 +866,10 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No :param ignoreTrailingWhiteSpace: a flag indicating whether or not trailing whitespaces from values being written should be skipped. If None is set, it uses the default value, ``true``. + :param charToEscapeQuoteEscaping: sets a single character used for escaping the escape for + the quote character. If None is set, the default value is + escape character when escape and quote characters are + different, ``\0`` otherwise.. 
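The writer accepts the same option; a minimal hedged sketch with a placeholder output path:

    df.write.csv("/tmp/output", mode="overwrite", header=True,
                 escape='\\', charToEscapeQuoteEscaping='\\')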
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -868,7 +878,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, dateFormat=dateFormat, timestampFormat=timestampFormat, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, - ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace) + ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index d0aba28788ac..24ae3776a217 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -560,7 +560,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -572,17 +572,17 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param path: string, or list of strings, for input path(s). :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). - :param sep: sets the single character as a separator for each field and value. + :param sep: sets a single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, it uses the default value, ``UTF-8``. - :param quote: sets the single character used for escaping quoted values where the + :param quote: sets a single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ``"``. If you would like to turn off quotations, you need to set an empty string. - :param escape: sets the single character used for escaping quotes inside an already + :param escape: sets a single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, ``\``. - :param comment: sets the single character used for skipping lines beginning with this + :param comment: sets a single character used for skipping lines beginning with this character. By default (None), it is disabled. :param header: uses the first line as names of columns. If None is set, it uses the default value, ``false``. @@ -638,6 +638,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. :param multiLine: parse one record, which may span multiple lines. If None is set, it uses the default value, ``false``. + :param charToEscapeQuoteEscaping: sets a single character used for escaping the escape for + the quote character. If None is set, the default value is + escape character when escape and quote characters are + different, ``\0`` otherwise.. 
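The streaming reader gains the option as well, and the clarified `trigger` parameters then control how the resulting query runs. A hedged end-to-end sketch (paths and schema are placeholders):

    from pyspark.sql.types import StructType, StringType

    schema = StructType().add("name", StringType()).add("city", StringType())
    stream_df = spark.readStream.csv("/tmp/incoming", schema=schema,
                                     charToEscapeQuoteEscaping='\\')

    # trigger(once=True) processes one batch of available data and then stops;
    # trigger(processingTime='30 seconds') would instead poll periodically.
    query = (stream_df.writeStream
             .format("parquet")
             .option("path", "/tmp/out")
             .option("checkpointLocation", "/tmp/cp")
             .trigger(once=True)
             .start())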
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -653,7 +657,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, - columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: @@ -788,6 +793,10 @@ def trigger(self, processingTime=None, once=None): .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + Set a trigger that runs a query periodically based on the processing + time. Only one trigger can be set. + :param once: if set to True, set a trigger that processes only one batch of data in a + streaming query then terminates the query. Only one trigger can be set. >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b977160af566..122a65b83aef 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -378,6 +378,55 @@ def test_udf2(self): [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_udf3(self): + twoargs = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) + self.assertEqual(twoargs.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 
'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -558,7 +607,6 @@ def test_read_multiple_orc_file(self): def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() @@ -566,7 +614,6 @@ def test_udf_with_input_file_name(self): def test_udf_with_input_file_name_for_hadooprdd(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType def filename(path): return path @@ -626,7 +673,6 @@ def test_udf_with_string_return_type(self): def test_udf_shouldnt_accept_noncallable_object(self): from pyspark.sql.functions import UserDefinedFunction - from pyspark.sql.types import StringType non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) @@ -1290,7 +1336,6 @@ def test_between_function(self): df.filter(df.a.between(df.b, df.c)).collect()) def test_struct_type(self): - from pyspark.sql.types import StructType, StringType, StructField struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) @@ -1359,7 +1404,6 @@ def test_parse_datatype_string(self): _parse_datatype_string("a INT, c DOUBLE")) def test_metadata_null(self): - from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) @@ -3142,6 +3186,7 @@ class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): from datetime import datetime + from decimal import Decimal ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java @@ -3158,11 +3203,15 @@ def setUpClass(cls): StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), StructField("5_double_t", DoubleType(), True), - StructField("6_date_t", DateType(), True), - StructField("7_timestamp_t", TimestampType(), True)]) - cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + StructField("6_decimal_t", DecimalType(38, 18), True), + StructField("7_date_t", DateType(), True), + StructField("8_timestamp_t", TimestampType(), True)]) + cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), + datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), + datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), + datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3190,10 +3239,11 @@ def create_pandas_data_frame(self): return 
pd.DataFrame(data=data_dict) def test_unsupported_datatype(self): - schema = StructType([StructField("decimal", DecimalType(), True)]) + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - self.assertRaises(Exception, lambda: df.toPandas()) + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.toPandas() def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + @@ -3293,7 +3343,7 @@ def test_createDataFrame_respect_session_timezone(self): self.assertNotEqual(result_ny, result_la) # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v + result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) @@ -3317,11 +3367,11 @@ def test_createDataFrame_with_incorrect_schema(self): def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list('abcdefg')) - self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + df = self.spark.createDataFrame(pdf, schema=list('abcdefgh')) + self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) - self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh')) + self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) def test_createDataFrame_column_name_encoding(self): import pandas as pd @@ -3344,7 +3394,7 @@ def test_createDataFrame_does_not_modify_input(self): # Some series get converted for Spark to consume, this makes sure input is unchanged pdf = self.create_pandas_data_frame() # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1) + pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) # Integers with nulls will get NaNs filled with 0 and will be casted pdf.ix[1, '2_int_t'] = None pdf_copy = pdf.copy(deep=True) @@ -3357,6 +3407,31 @@ def test_schema_conversion_roundtrip(self): schema_rt = from_arrow_schema(arrow_schema) self.assertEquals(self.schema, schema_rt) + def test_createDataFrame_with_array_type(self): + import pandas as pd + pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]}) + df, df_arrow = self._createDataFrame_toggle(pdf) + result = df.collect() + result_arrow = df_arrow.collect() + expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + + def test_toPandas_with_array_type(self): + expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])] + array_schema = StructType([StructField("a", ArrayType(IntegerType())), + StructField("b", ArrayType(StringType()))]) + df = self.spark.createDataFrame(expected, schema=array_schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + result_arrow = [tuple(list(e) for e in rec) for rec in 
pdf_arrow.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): @@ -3506,6 +3581,18 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() + @property + def random_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3514,6 +3601,7 @@ def test_vectorized_udf_basic(self): col('id').alias('long'), col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), col('id').cast('boolean').alias('bool')) f = lambda x: x str_f = pandas_udf(f, StringType()) @@ -3521,10 +3609,12 @@ def test_vectorized_udf_basic(self): long_f = pandas_udf(f, LongType()) float_f = pandas_udf(f, FloatType()) double_f = pandas_udf(f, DoubleType()) + decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), - double_f(col('double')), bool_f(col('bool'))) + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_boolean(self): @@ -3590,6 +3680,16 @@ def test_vectorized_udf_null_double(self): res = df.select(double_f(col('double'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_decimal(self): + from decimal import Decimal + from pyspark.sql.functions import pandas_udf, col + data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] + schema = StructType().add("decimal", DecimalType(38, 18)) + df = self.spark.createDataFrame(data, schema) + decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18)) + res = df.select(decimal_f(col('decimal'))) + self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_string(self): from pyspark.sql.functions import pandas_udf, col data = [("foo",), (None,), ("bar",), ("bar",)] @@ -3607,6 +3707,7 @@ def test_vectorized_udf_datatype_string(self): col('id').alias('long'), col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), col('id').cast('boolean').alias('bool')) f = lambda x: x str_f = pandas_udf(f, 'string') @@ -3614,12 +3715,32 @@ def test_vectorized_udf_datatype_string(self): long_f = pandas_udf(f, 'long') float_f = pandas_udf(f, 'float') double_f = pandas_udf(f, 'double') + decimal_f = pandas_udf(f, 'decimal(38, 18)') bool_f = pandas_udf(f, 'boolean') res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), - double_f(col('double')), bool_f(col('bool'))) + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_array_type(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), ([3, 4],)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + 
array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + + def test_vectorized_udf_null_array(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + def test_vectorized_udf_complex(self): from pyspark.sql.functions import pandas_udf, col, expr df = self.spark.range(10).select( @@ -3674,7 +3795,7 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType())) + f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): df.select(f(col('id'))).collect() @@ -3713,12 +3834,12 @@ def test_vectorized_udf_varargs(self): def test_vectorized_udf_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("dt", DecimalType(), True)]) + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, DecimalType()) + f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('dt'))).collect() + df.select(f(col('map'))).collect() def test_vectorized_udf_null_date(self): from pyspark.sql.functions import pandas_udf, col @@ -3794,6 +3915,7 @@ def gen_timestamps(id): def test_vectorized_udf_check_config(self): from pyspark.sql.functions import pandas_udf, col + import pandas as pd orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) try: @@ -3801,11 +3923,11 @@ def test_vectorized_udf_check_config(self): @pandas_udf(returnType=LongType()) def check_records_per_batch(x): - self.assertTrue(x.size <= 3) - return x + return pd.Series(x.size).repeat(x.size) - result = df.select(check_records_per_batch(col("id"))) - self.assertEqual(df.collect(), result.collect()) + result = df.select(check_records_per_batch(col("id"))).collect() + for (r,) in result: + self.assertTrue(r <= 3) finally: if orig_value is None: self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") @@ -3854,6 +3976,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf, pandas_udf, col + + @pandas_udf('double') + def plus_ten(v): + return v + 10 + random_udf = self.random_udf + + df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) + result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) + + def 
test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import pandas_udf, sum + + df = self.spark.range(10) + random_udf = self.random_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.agg(sum(random_udf(df.id))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): @@ -3977,7 +4126,7 @@ def test_wrong_return_type(self): foo = pandas_udf( lambda pdf: pdf, - 'id long, v array', + 'id long, v map', PandasUDFType.GROUP_MAP ) @@ -4012,7 +4161,8 @@ def test_wrong_args(self): def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType schema = StructType( - [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) + [StructField("id", LongType(), True), + StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 063264a89379..146e673ae975 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1617,7 +1617,7 @@ def to_arrow_type(dt): elif type(dt) == DoubleType: arrow_type = pa.float64() elif type(dt) == DecimalType: - arrow_type = pa.decimal(dt.precision, dt.scale) + arrow_type = pa.decimal128(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() elif type(dt) == DateType: @@ -1625,6 +1625,8 @@ def to_arrow_type(dt): elif type(dt) == TimestampType: # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') + elif type(dt) == ArrayType: + arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type @@ -1665,6 +1667,8 @@ def from_arrow_type(at): spark_type = DateType() elif types.is_timestamp(at): spark_type = TimestampType() + elif types.is_list(at): + spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 123138117fdc..5e80ab916586 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType): ) # Set the name of the UserDefinedFunction object to be the name of function f - udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + udf_obj = UserDefinedFunction( + f, returnType=returnType, name=None, evalType=evalType, deterministic=True) return udf_obj._wrapped() @@ -67,8 +68,10 @@ class UserDefinedFunction(object): .. 
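The `to_arrow_type`/`from_arrow_type` changes above extend the Arrow conversion path to `DecimalType` (via `pa.decimal128`) and `ArrayType`. A hedged sketch of exercising that path through `toPandas`, assuming pyarrow >= 0.8.0 is installed (data and column names are illustrative):

    from decimal import Decimal
    from pyspark.sql.types import StructType, ArrayType, DecimalType, IntegerType

    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    schema = (StructType()
              .add("amount", DecimalType(38, 18))
              .add("values", ArrayType(IntegerType())))
    df = spark.createDataFrame([(Decimal("1.50"), [1, 2]), (Decimal("2.25"), [3, 4])], schema)
    pdf = df.toPandas()  # decimal and array columns now go through the Arrow path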
versionadded:: 1.3 """ def __init__(self, func, - returnType=StringType(), name=None, - evalType=PythonEvalType.SQL_BATCHED_UDF): + returnType=StringType(), + name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=True): if not callable(func): raise TypeError( "Invalid function: not a function or callable (__call__ is not defined): " @@ -92,6 +95,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType + self.deterministic = deterministic @property def returnType(self): @@ -129,7 +133,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType) + self._name, wrapped_func, jdt, self.evalType, self.deterministic) return judf def __call__(self, *cols): @@ -137,6 +141,9 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. (See: SPARK-19161) def _wrapped(self): """ Wrap this udf with a function and attach docstring from func @@ -161,5 +168,16 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType - + wrapper.deterministic = self.deterministic + wrapper.asNondeterministic = functools.wraps( + self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped()) return wrapper + + def asNondeterministic(self): + """ + Updates UserDefinedFunction to nondeterministic. + + .. versionadded:: 2.3 + """ + self.deterministic = False + return self diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index fb7d42a35d8f..08c34c6dccc5 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -118,7 +118,8 @@ def require_minimum_pandas_version(): from distutils.version import LooseVersion import pandas if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'): - raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process") + raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; " + "however, your version was %s." % pandas.__version__) def require_minimum_pyarrow_version(): @@ -127,4 +128,5 @@ def require_minimum_pyarrow_version(): from distutils.version import LooseVersion import pyarrow if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'): - raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process") + raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; " + "however, your version was %s." 
% pyarrow.__version__) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index 5a975d050b0d..5de448114ece 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -20,6 +20,8 @@ from io import BytesIO else: from StringIO import StringIO +import warnings + from py4j.protocol import Py4JJavaError from pyspark.storagelevel import StorageLevel diff --git a/python/setup.py b/python/setup.py index 310670e697a8..251d4526d4dd 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.19.2'] + 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0'] }, classifiers=[ 'Development Status :: 5 - Production/Stable', @@ -210,6 +210,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 45f527959cbe..e5d79d9a9d9d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -20,7 +20,6 @@ import java.util.concurrent.TimeUnit import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigBuilder -import org.apache.spark.network.util.ByteUnit private[spark] object Config extends Logging { @@ -132,20 +131,72 @@ private[spark] object Config extends Logging { val JARS_DOWNLOAD_LOCATION = ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") - .doc("Location to download jars to in the driver and executors. When using" + - " spark-submit, this directory must be empty and will be mounted as an empty directory" + - " volume on the driver and executor pod.") + .doc("Location to download jars to in the driver and executors. When using " + + "spark-submit, this directory must be empty and will be mounted as an empty directory " + + "volume on the driver and executor pod.") .stringConf .createWithDefault("/var/spark-data/spark-jars") val FILES_DOWNLOAD_LOCATION = ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. When using" + - " spark-submit, this directory must be empty and will be mounted as an empty directory" + - " volume on the driver and executor pods.") + .doc("Location to download files to in the driver and executors. 
When using " + + "spark-submit, this directory must be empty and will be mounted as an empty directory " + + "volume on the driver and executor pods.") .stringConf .createWithDefault("/var/spark-data/spark-files") + val INIT_CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.initContainer.image") + .doc("Image for the driver and executor's init-container for downloading dependencies.") + .stringConf + .createOptional + + val INIT_CONTAINER_MOUNT_TIMEOUT = + ConfigBuilder("spark.kubernetes.mountDependencies.timeout") + .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " + + "locations into the driver and executor pods.") + .timeConf(TimeUnit.SECONDS) + .createWithDefault(300) + + val INIT_CONTAINER_MAX_THREAD_POOL_SIZE = + ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads") + .doc("Maximum number of remote dependencies to download simultaneously in a driver or " + + "executor pod.") + .intConf + .createWithDefault(5) + + val INIT_CONTAINER_REMOTE_JARS = + ConfigBuilder("spark.kubernetes.initContainer.remoteJars") + .doc("Comma-separated list of jar URIs to download in the init-container. This is " + + "calculated from spark.jars.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_REMOTE_FILES = + ConfigBuilder("spark.kubernetes.initContainer.remoteFiles") + .doc("Comma-separated list of file URIs to download in the init-container. This is " + + "calculated from spark.files.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_CONFIG_MAP_NAME = + ConfigBuilder("spark.kubernetes.initContainer.configMapName") + .doc("Name of the config map to use in the init-container that retrieves submitted files " + + "for the executor.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_CONFIG_MAP_KEY_CONF = + ConfigBuilder("spark.kubernetes.initContainer.configMapKey") + .doc("Key for the entry in the init container config map for submitted files that " + + "corresponds to the properties for this init-container.") + .internal() + .stringConf + .createOptional + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -153,9 +204,11 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." + val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." 
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 0b91145405d3..111cb2a3b75e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -69,6 +69,17 @@ private[spark] object Constants { val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" + val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" + + // Bootstrapping dependencies with the init-container + val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" + val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume" + val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" + val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" + val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" + val INIT_CONTAINER_PROPERTIES_FILE_PATH = + s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" + val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala new file mode 100644 index 000000000000..dfeccf9e2bd1 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder} + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +/** + * Bootstraps an init-container for downloading remote dependencies. This is separated out from + * the init-container steps API because this component can be used to bootstrap init-containers + * for both the driver and executors. 
+ */ +private[spark] class InitContainerBootstrap( + initContainerImage: String, + imagePullPolicy: String, + jarsDownloadPath: String, + filesDownloadPath: String, + configMapName: String, + configMapKey: String, + sparkRole: String, + sparkConf: SparkConf) { + + /** + * Bootstraps an init-container that downloads dependencies to be used by a main container. + */ + def bootstrapInitContainer( + original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = { + val sharedVolumeMounts = Seq[VolumeMount]( + new VolumeMountBuilder() + .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) + .withMountPath(jarsDownloadPath) + .build(), + new VolumeMountBuilder() + .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) + .withMountPath(filesDownloadPath) + .build()) + + val customEnvVarKeyPrefix = sparkRole match { + case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY + case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv." + case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role") + } + val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map { + case (key, value) => + new EnvVarBuilder() + .withName(key) + .withValue(value) + .build() + } + + val initContainer = new ContainerBuilder(original.initContainer) + .withName("spark-init") + .withImage(initContainerImage) + .withImagePullPolicy(imagePullPolicy) + .addAllToEnv(customEnvVars.asJava) + .addNewVolumeMount() + .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) + .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) + .endVolumeMount() + .addToVolumeMounts(sharedVolumeMounts: _*) + .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) + .build() + + val podWithBasicVolumes = new PodBuilder(original.pod) + .editSpec() + .addNewVolume() + .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) + .withNewConfigMap() + .withName(configMapName) + .addNewItem() + .withKey(configMapKey) + .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME) + .endItem() + .endConfigMap() + .endVolume() + .addNewVolume() + .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) + .withEmptyDir(new EmptyDirVolumeSource()) + .endVolume() + .addNewVolume() + .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) + .withEmptyDir(new EmptyDirVolumeSource()) + .endVolume() + .endSpec() + .build() + + val mainContainer = new ContainerBuilder(original.mainContainer) + .addToVolumeMounts(sharedVolumeMounts: _*) + .addNewEnv() + .withName(ENV_MOUNTED_FILES_DIR) + .withValue(filesDownloadPath) + .endEnv() + .build() + + PodWithDetachedInitContainer( + podWithBasicVolumes, + initContainer, + mainContainer) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala similarity index 57% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index a38cf55fc3d5..37331d8bbf9b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -14,13 +14,49 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import java.io.File +import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder} + +import org.apache.spark.SparkConf import org.apache.spark.util.Utils -private[spark] object KubernetesFileUtils { +private[spark] object KubernetesUtils { + + /** + * Extract and parse Spark configuration properties with a given name prefix and + * return the result as a Map. Keys must not have more than one value. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing the configuration property keys and values + */ + def parsePrefixedKeyValuePairs( + sparkConf: SparkConf, + prefix: String): Map[String, String] = { + sparkConf.getAllWithPrefix(prefix).toMap + } + + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { + opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + } + + /** + * Append the given init-container to a pod's list of init-containers. + * + * @param originalPodSpec original specification of the pod + * @param initContainer the init-container to add to the pod + * @return the pod with the init-container added to the list of InitContainers + */ + def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = { + new PodBuilder(originalPodSpec) + .editOrNewSpec() + .addToInitContainers(initContainer) + .endSpec() + .build() + } /** * For the given collection of file URIs, resolves them as follows: @@ -47,6 +83,16 @@ private[spark] object KubernetesFileUtils { } } + /** + * Get from a given collection of file URIs the ones that represent remote files. + */ + def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = { + uris.filter { uri => + val scheme = Utils.resolveURI(uri).getScheme + scheme != "file" && scheme != "local" + } + } + private def resolveFileUri( uri: String, fileDownloadPath: String, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala new file mode 100644 index 000000000000..c35e7db51d40 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} + +/** + * Bootstraps a driver or executor container or an init-container with needed secrets mounted. 
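+ *
+ * A short sketch of the intended usage together with KubernetesUtils.parsePrefixedKeyValuePairs
+ * (illustrative only; the secret names and mount paths are examples, not defaults):
+ *
+ * {{{
+ *   import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder}
+ *   import org.apache.spark.SparkConf
+ *   import org.apache.spark.deploy.k8s.Config.KUBERNETES_DRIVER_SECRETS_PREFIX
+ *
+ *   val sparkConf = new SparkConf(false)
+ *     .set(s"${KUBERNETES_DRIVER_SECRETS_PREFIX}gcs-creds", "/mnt/secrets/gcs")
+ *     .set(s"${KUBERNETES_DRIVER_SECRETS_PREFIX}db-creds", "/mnt/secrets/db")
+ *
+ *   // Yields Map("gcs-creds" -> "/mnt/secrets/gcs", "db-creds" -> "/mnt/secrets/db").
+ *   val secretNamesToMountPaths =
+ *     KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX)
+ *
+ *   val bootstrap = new MountSecretsBootstrap(secretNamesToMountPaths)
+ *   val podWithVolumes = bootstrap.addSecretVolumes(new PodBuilder().build())
+ *   val containerWithMounts = bootstrap.mountSecrets(new ContainerBuilder().build())
+ * }}}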
+ */ +private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { + + /** + * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. + * + * @param pod the pod into which the secret volumes are being added. + * @return the updated pod with the secret volumes added. + */ + def addSecretVolumes(pod: Pod): Pod = { + var podBuilder = new PodBuilder(pod) + secretNamesToMountPaths.keys.foreach { name => + podBuilder = podBuilder + .editOrNewSpec() + .addNewVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() + .endSpec() + } + + podBuilder.build() + } + + /** + * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the + * given container. + * + * @param container the container into which the secret volumes are being mounted. + * @return the updated container with the secrets mounted. + */ + def mountSecrets(container: Container): Container = { + var containerBuilder = new ContainerBuilder(container) + secretNamesToMountPaths.foreach { case (name, path) => + containerBuilder = containerBuilder + .addNewVolumeMount() + .withName(secretVolumeName(name)) + .withMountPath(path) + .endVolumeMount() + } + + containerBuilder.build() + } + + private def secretVolumeName(secretName: String): String = { + secretName + "-volume" + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala new file mode 100644 index 000000000000..0b79f8b12e80 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{Container, Pod} + +/** + * Represents a pod with a detached init-container (not yet added to the pod). 
+ * + * @param pod the pod + * @param initContainer the init-container in the pod + * @param mainContainer the main container in the pod + */ +private[spark] case class PodWithDetachedInitContainer( + pod: Pod, + initContainer: Container, + mainContainer: Container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 1e3f055e0576..c47e78cbf19e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -48,7 +48,7 @@ private[spark] object SparkKubernetesClientFactory { .map(new File(_)) .orElse(defaultServiceAccountToken) val oauthTokenValue = sparkConf.getOption(oauthTokenConf) - ConfigurationUtils.requireNandDefined( + KubernetesUtils.requireNandDefined( oauthTokenFile, oauthTokenValue, s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a " + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala new file mode 100644 index 000000000000..c0f08786b76a --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.io.File +import java.util.concurrent.TimeUnit + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Process that fetches files from a resource staging server and/or arbitrary remote locations. + * + * The init-container can handle fetching files from any of those sources, but not all of the + * sources need to be specified. This allows for composing multiple instances of this container + * with different configurations for different download sources, or using the same container to + * download everything at once. 
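+ *
+ * The process is started with the path of the properties file mounted from the init-container
+ * ConfigMap; that file supplies the download directories (JARS_DOWNLOAD_LOCATION and
+ * FILES_DOWNLOAD_LOCATION) and the remote dependency lists (INIT_CONTAINER_REMOTE_JARS and
+ * INIT_CONTAINER_REMOTE_FILES). A minimal launch sketch (illustrative only):
+ *
+ * {{{
+ *   // Inside the init-container, which InitContainerBootstrap starts with the mounted
+ *   // properties file path as its only argument:
+ *   SparkPodInitContainer.main(Array(Constants.INIT_CONTAINER_PROPERTIES_FILE_PATH))
+ * }}}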
+ */ +private[spark] class SparkPodInitContainer( + sparkConf: SparkConf, + fileFetcher: FileFetcher) extends Logging { + + private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE) + private implicit val downloadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize)) + + private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION)) + private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION)) + + private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS) + private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES) + + private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT) + + def run(): Unit = { + logInfo(s"Downloading remote jars: $remoteJars") + downloadFiles( + remoteJars, + jarsDownloadDir, + s"Remote jars download directory specified at $jarsDownloadDir does not exist " + + "or is not a directory.") + + logInfo(s"Downloading remote files: $remoteFiles") + downloadFiles( + remoteFiles, + filesDownloadDir, + s"Remote files download directory specified at $filesDownloadDir does not exist " + + "or is not a directory.") + + downloadExecutor.shutdown() + downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES) + } + + private def downloadFiles( + filesCommaSeparated: Option[String], + downloadDir: File, + errMessage: String): Unit = { + filesCommaSeparated.foreach { files => + require(downloadDir.isDirectory, errMessage) + Utils.stringToSeq(files).foreach { file => + Future[Unit] { + fileFetcher.fetchFile(file, downloadDir) + } + } + } + } +} + +private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) { + + def fetchFile(uri: String, targetDir: File): Unit = { + Utils.fetchFile( + url = uri, + targetDir = targetDir, + conf = sparkConf, + securityMgr = securityManager, + hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf), + timestamp = System.currentTimeMillis(), + useCache = false) + } +} + +object SparkPodInitContainer extends Logging { + + def main(args: Array[String]): Unit = { + logInfo("Starting init-container to download Spark application dependencies.") + val sparkConf = new SparkConf(true) + if (args.nonEmpty) { + Utils.loadDefaultSparkProperties(sparkConf, args(0)) + } + + val securityManager = new SparkSecurityManager(sparkConf) + val fileFetcher = new FileFetcher(sparkConf, securityManager) + new SparkPodInitContainer(sparkConf, fileFetcher).run() + logInfo("Finished downloading application dependencies.") + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala similarity index 53% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 1411e6f40b46..c9cc300d6556 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -21,25 +21,31 @@ import java.util.UUID import com.google.common.primitives.Longs import 
org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.steps._ +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.SystemClock +import org.apache.spark.util.Utils /** - * Constructs the complete list of driver configuration steps to run to deploy the Spark driver. + * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to + * configure the Spark driver pod. The returned steps will be applied one by one in the given + * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication + * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to + * configure the driver init-container if one is needed, i.e., when there are remote dependencies + * to localize. */ -private[spark] class DriverConfigurationStepsOrchestrator( - namespace: String, +private[spark] class DriverConfigOrchestrator( kubernetesAppId: String, launchTime: Long, mainAppResource: Option[MainAppResource], appName: String, mainClass: String, appArgs: Array[String], - submissionSparkConf: SparkConf) { + sparkConf: SparkConf) { // The resource name prefix is derived from the Spark application name, making it easy to connect // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the @@ -49,13 +55,14 @@ private[spark] class DriverConfigurationStepsOrchestrator( s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") } - private val imagePullPolicy = submissionSparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val jarsDownloadPath = submissionSparkConf.get(JARS_DOWNLOAD_LOCATION) - private val filesDownloadPath = submissionSparkConf.get(FILES_DOWNLOAD_LOCATION) + private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config" + private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION) + private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION) - def getAllConfigurationSteps(): Seq[DriverConfigurationStep] = { - val driverCustomLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( - submissionSparkConf, + def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + @@ -64,11 +71,15 @@ private[spark] class DriverConfigurationStepsOrchestrator( s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + "operations.") + val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_DRIVER_SECRETS_PREFIX) + val allDriverLabels = driverCustomLabels ++ Map( SPARK_APP_ID_LABEL -> kubernetesAppId, SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - val initialSubmissionStep = new BaseDriverConfigurationStep( + val initialSubmissionStep = new BasicDriverConfigurationStep( kubernetesAppId, kubernetesResourceNamePrefix, allDriverLabels, @@ -76,16 +87,16 @@ private[spark] class DriverConfigurationStepsOrchestrator( appName, mainClass, 
appArgs, - submissionSparkConf) + sparkConf) - val driverAddressStep = new DriverServiceBootstrapStep( + val serviceBootstrapStep = new DriverServiceBootstrapStep( kubernetesResourceNamePrefix, allDriverLabels, - submissionSparkConf, + sparkConf, new SystemClock) val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, kubernetesResourceNamePrefix) + sparkConf, kubernetesResourceNamePrefix) val additionalMainAppJar = if (mainAppResource.nonEmpty) { val mayBeResource = mainAppResource.get match { @@ -98,28 +109,62 @@ private[spark] class DriverConfigurationStepsOrchestrator( None } - val sparkJars = submissionSparkConf.getOption("spark.jars") + val sparkJars = sparkConf.getOption("spark.jars") .map(_.split(",")) .getOrElse(Array.empty[String]) ++ additionalMainAppJar.toSeq - val sparkFiles = submissionSparkConf.getOption("spark.files") + val sparkFiles = sparkConf.getOption("spark.files") .map(_.split(",")) .getOrElse(Array.empty[String]) - val maybeDependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Some(new DependencyResolutionStep( + val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { + Seq(new DependencyResolutionStep( sparkJars, sparkFiles, jarsDownloadPath, filesDownloadPath)) } else { - None + Nil + } + + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { + val orchestrator = new InitContainerConfigOrchestrator( + sparkJars, + sparkFiles, + jarsDownloadPath, + filesDownloadPath, + imagePullPolicy, + initContainerConfigMapName, + INIT_CONTAINER_PROPERTIES_FILE_NAME, + sparkConf) + val bootstrapStep = new DriverInitContainerBootstrapStep( + orchestrator.getAllConfigurationSteps, + initContainerConfigMapName, + INIT_CONTAINER_PROPERTIES_FILE_NAME) + + Seq(bootstrapStep) + } else { + Nil } Seq( initialSubmissionStep, - driverAddressStep, + serviceBootstrapStep, kubernetesCredentialsStep) ++ - maybeDependencyResolutionStep.toSeq + dependencyResolutionStep ++ + mountSecretsStep ++ + initContainerBootstrapStep + } + + private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { + files.exists { uri => + Utils.resolveURI(uri).getScheme != "local" + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 240a1144577b..5884348cb3e4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -80,22 +80,22 @@ private[spark] object ClientArguments { * spark.kubernetes.submission.waitAppCompletion is true. 
* * @param submissionSteps steps that collectively configure the driver - * @param submissionSparkConf the submission client Spark configuration + * @param sparkConf the submission client Spark configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete * @param appName the application name - * @param loggingPodStatusWatcher a watcher that monitors and logs the application status + * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( submissionSteps: Seq[DriverConfigurationStep], - submissionSparkConf: SparkConf, + sparkConf: SparkConf, kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, - loggingPodStatusWatcher: LoggingPodStatusWatcher) extends Logging { + watcher: LoggingPodStatusWatcher) extends Logging { - private val driverJavaOptions = submissionSparkConf.get( + private val driverJavaOptions = sparkConf.get( org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) /** @@ -104,7 +104,7 @@ private[spark] class Client( * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(submissionSparkConf) + var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) // submissionSteps contain steps necessary to take, to resolve varying // client arguments that are passed in, created by orchestrator for (nextStep <- submissionSteps) { @@ -141,7 +141,7 @@ private[spark] class Client( kubernetesClient .pods() .withName(resolvedDriverPod.getMetadata.getName) - .watch(loggingPodStatusWatcher)) { _ => + .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { if (currentDriverSpec.otherKubernetesResources.nonEmpty) { @@ -157,7 +157,7 @@ private[spark] class Client( if (waitForAppCompletion) { logInfo(s"Waiting for application $appName to finish...") - loggingPodStatusWatcher.awaitCompletion() + watcher.awaitCompletion() logInfo(s"Application $appName finished.") } else { logInfo(s"Deployed Spark application $appName into Kubernetes.") @@ -207,11 +207,9 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val master = sparkConf.get("spark.master").substring("k8s://".length) val loggingInterval = if (waitForAppCompletion) Some(sparkConf.get(REPORT_INTERVAL)) else None - val loggingPodStatusWatcher = new LoggingPodStatusWatcherImpl( - kubernetesAppId, loggingInterval) + val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val configurationStepsOrchestrator = new DriverConfigurationStepsOrchestrator( - namespace, + val orchestrator = new DriverConfigOrchestrator( kubernetesAppId, launchTime, clientArguments.mainAppResource, @@ -228,12 +226,12 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - configurationStepsOrchestrator.getAllConfigurationSteps(), + orchestrator.getAllConfigurationSteps, sparkConf, kubernetesClient, waitForAppCompletion, appName, - loggingPodStatusWatcher) + watcher) client.run() } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala similarity index 68% 
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index c335fcce4036..eca46b84c606 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -22,49 +22,46 @@ import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarS import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesUtils import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} /** - * Represents the initial setup required for the driver. + * Performs basic configuration for the driver pod. */ -private[spark] class BaseDriverConfigurationStep( +private[spark] class BasicDriverConfigurationStep( kubernetesAppId: String, - kubernetesResourceNamePrefix: String, + resourceNamePrefix: String, driverLabels: Map[String, String], imagePullPolicy: String, appName: String, mainClass: String, appArgs: Array[String], - submissionSparkConf: SparkConf) extends DriverConfigurationStep { + sparkConf: SparkConf) extends DriverConfigurationStep { - private val kubernetesDriverPodName = submissionSparkConf.get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$kubernetesResourceNamePrefix-driver") + private val driverPodName = sparkConf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"$resourceNamePrefix-driver") - private val driverExtraClasspath = submissionSparkConf.get( - DRIVER_CLASS_PATH) + private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - private val driverContainerImage = submissionSparkConf + private val driverContainerImage = sparkConf .get(DRIVER_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the driver container image")) // CPU settings - private val driverCpuCores = submissionSparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = submissionSparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) + private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") + private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) // Memory settings - private val driverMemoryMiB = submissionSparkConf.get( - DRIVER_MEMORY) - private val driverMemoryString = submissionSparkConf.get( - DRIVER_MEMORY.key, - DRIVER_MEMORY.defaultValueString) - private val memoryOverheadMiB = submissionSparkConf + private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) + private val driverMemoryString = sparkConf.get( + DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString) + private val memoryOverheadMiB = sparkConf .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val driverContainerMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def 
configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => @@ -74,15 +71,13 @@ private[spark] class BaseDriverConfigurationStep( .build() } - val driverCustomAnnotations = ConfigurationUtils - .parsePrefixedKeyValuePairs( - submissionSparkConf, - KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + " Spark bookkeeping operations.") - val driverCustomEnvs = submissionSparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq + val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq .map { env => new EnvVarBuilder() .withName(env._1) @@ -90,10 +85,10 @@ private[spark] class BaseDriverConfigurationStep( .build() } - val allDriverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) + val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - val nodeSelector = ConfigurationUtils.parsePrefixedKeyValuePairs( - submissionSparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) val driverCpuQuantity = new QuantityBuilder(false) .withAmount(driverCpuCores) @@ -102,7 +97,7 @@ private[spark] class BaseDriverConfigurationStep( .withAmount(s"${driverMemoryMiB}Mi") .build() val driverMemoryLimitQuantity = new QuantityBuilder(false) - .withAmount(s"${driverContainerMemoryWithOverheadMiB}Mi") + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") .build() val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) @@ -124,7 +119,7 @@ private[spark] class BaseDriverConfigurationStep( .endEnv() .addNewEnv() .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) + .withValue(appArgs.mkString(" ")) .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) @@ -142,9 +137,9 @@ private[spark] class BaseDriverConfigurationStep( val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() - .withName(kubernetesDriverPodName) + .withName(driverPodName) .addToLabels(driverLabels.asJava) - .addToAnnotations(allDriverAnnotations.asJava) + .addToAnnotations(driverAnnotations.asJava) .endMetadata() .withNewSpec() .withRestartPolicy("Never") @@ -153,9 +148,9 @@ private[spark] class BaseDriverConfigurationStep( .build() val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, kubernetesDriverPodName) + .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, kubernetesResourceNamePrefix) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) driverSpec.copy( driverPod = baseDriverPod, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala index 44e0ecffc0e9..d4b83235b4e3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala @@ -21,7 +21,8 @@ import java.io.File import io.fabric8.kubernetes.api.model.ContainerBuilder import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, KubernetesFileUtils} +import org.apache.spark.deploy.k8s.KubernetesUtils +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec /** * Step that configures the classpath, spark.jars, and spark.files for the driver given that the @@ -31,21 +32,22 @@ private[spark] class DependencyResolutionStep( sparkJars: Seq[String], sparkFiles: Seq[String], jarsDownloadPath: String, - localFilesDownloadPath: String) extends DriverConfigurationStep { + filesDownloadPath: String) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesFileUtils.resolveFileUris(sparkJars, jarsDownloadPath) - val resolvedSparkFiles = KubernetesFileUtils.resolveFileUris( - sparkFiles, localFilesDownloadPath) - val sparkConfResolvedSparkDependencies = driverSpec.driverSparkConf.clone() + val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath) + val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath) + + val sparkConf = driverSpec.driverSparkConf.clone() if (resolvedSparkJars.nonEmpty) { - sparkConfResolvedSparkDependencies.set("spark.jars", resolvedSparkJars.mkString(",")) + sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) } if (resolvedSparkFiles.nonEmpty) { - sparkConfResolvedSparkDependencies.set("spark.files", resolvedSparkFiles.mkString(",")) + sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) } - val resolvedClasspath = KubernetesFileUtils.resolveFilePaths(sparkJars, jarsDownloadPath) - val driverContainerWithResolvedClasspath = if (resolvedClasspath.nonEmpty) { + + val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath) + val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) { new ContainerBuilder(driverSpec.driverContainer) .addNewEnv() .withName(ENV_MOUNTED_CLASSPATH) @@ -55,8 +57,9 @@ private[spark] class DependencyResolutionStep( } else { driverSpec.driverContainer } + driverSpec.copy( - driverContainer = driverContainerWithResolvedClasspath, - driverSparkConf = sparkConfResolvedSparkDependencies) + driverContainer = resolvedDriverContainer, + driverSparkConf = sparkConf) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala index c99c0436cf25..17614e040e58 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec /** - * Represents a step in preparing the Kubernetes driver. + * Represents a step in configuring the Spark driver pod. 
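+ *
+ * A sketch of a trivial custom step (illustrative only; the annotation key and value are
+ * made-up placeholders):
+ *
+ * {{{
+ *   import io.fabric8.kubernetes.api.model.PodBuilder
+ *   import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
+ *
+ *   private[spark] class ExampleAnnotationStep extends DriverConfigurationStep {
+ *     override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
+ *       val annotatedPod = new PodBuilder(driverSpec.driverPod)
+ *         .editOrNewMetadata()
+ *           .addToAnnotations("example.com/team", "data-platform")
+ *           .endMetadata()
+ *         .build()
+ *       driverSpec.copy(driverPod = annotatedPod)
+ *     }
+ *   }
+ * }}}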
*/ private[spark] trait DriverConfigurationStep { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala new file mode 100644 index 000000000000..9fb3dafdda54 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.StringWriter +import java.util.Properties + +import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.KubernetesUtils +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} + +/** + * Configures the driver init-container that localizes remote dependencies into the driver pod. + * It applies the given InitContainerConfigurationSteps in the given order to produce a final + * InitContainerSpec that is then used to configure the driver pod with the init-container attached. + * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries + * configuration properties for the init-container. 
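+ *
+ * The ConfigMap content is plain Java properties text; conceptually (illustrative sketch, the
+ * property key and value are placeholders):
+ *
+ * {{{
+ *   import java.io.{StringReader, StringWriter}
+ *   import java.util.Properties
+ *
+ *   val properties = new Properties()
+ *   properties.setProperty("spark.example.property", "value")
+ *   val writer = new StringWriter()
+ *   properties.store(writer, "spark-init.properties content")
+ *   // writer.toString is what is stored under the ConfigMap key and later mounted into the
+ *   // init-container, where it is read back with:
+ *   val loaded = new Properties()
+ *   loaded.load(new StringReader(writer.toString))
+ * }}}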
+ */ +private[spark] class DriverInitContainerBootstrapStep( + steps: Seq[InitContainerConfigurationStep], + configMapName: String, + configMapKey: String) + extends DriverConfigurationStep { + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + var initContainerSpec = InitContainerSpec( + properties = Map.empty[String, String], + driverSparkConf = Map.empty[String, String], + initContainer = new ContainerBuilder().build(), + driverContainer = driverSpec.driverContainer, + driverPod = driverSpec.driverPod, + dependentResources = Seq.empty[HasMetadata]) + for (nextStep <- steps) { + initContainerSpec = nextStep.configureInitContainer(initContainerSpec) + } + + val configMap = buildConfigMap( + configMapName, + configMapKey, + initContainerSpec.properties) + val resolvedDriverSparkConf = driverSpec.driverSparkConf + .clone() + .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName) + .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey) + .setAll(initContainerSpec.driverSparkConf) + val resolvedDriverPod = KubernetesUtils.appendInitContainer( + initContainerSpec.driverPod, initContainerSpec.initContainer) + + driverSpec.copy( + driverPod = resolvedDriverPod, + driverContainer = initContainerSpec.driverContainer, + driverSparkConf = resolvedDriverSparkConf, + otherKubernetesResources = + driverSpec.otherKubernetesResources ++ + initContainerSpec.dependentResources ++ + Seq(configMap)) + } + + private def buildConfigMap( + configMapName: String, + configMapKey: String, + config: Map[String, String]): ConfigMap = { + val properties = new Properties() + config.foreach { entry => + properties.setProperty(entry._1, entry._2) + } + val propertiesWriter = new StringWriter() + properties.store(propertiesWriter, + s"Java properties built from Kubernetes config map with name: $configMapName " + + s"and config map key: $configMapKey") + new ConfigMapBuilder() + .withNewMetadata() + .withName(configMapName) + .endMetadata() + .addToData(configMapKey, propertiesWriter.toString) + .build() + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala new file mode 100644 index 000000000000..91e9a9f21133 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.submit.steps + +import org.apache.spark.deploy.k8s.MountSecretsBootstrap +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec + +/** + * A driver configuration step for mounting user-specified secrets onto user-specified paths. + * + * @param bootstrap a utility actually handling mounting of the secrets. + */ +private[spark] class DriverMountSecretsStep( + bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) + val container = bootstrap.mountSecrets(driverSpec.driverContainer) + driverSpec.copy( + driverPod = pod, + driverContainer = container + ) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala index 696d11f15ed9..eb594e4f16ec 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala @@ -32,21 +32,22 @@ import org.apache.spark.util.Clock * ports should correspond to the ports that the executor will reach the pod at for RPC. */ private[spark] class DriverServiceBootstrapStep( - kubernetesResourceNamePrefix: String, + resourceNamePrefix: String, driverLabels: Map[String, String], - submissionSparkConf: SparkConf, + sparkConf: SparkConf, clock: Clock) extends DriverConfigurationStep with Logging { + import DriverServiceBootstrapStep._ override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(submissionSparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + "address is managed and set to the driver pod's IP address.") - require(submissionSparkConf.getOption(DRIVER_HOST_KEY).isEmpty, + require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + "managed via a Kubernetes service.") - val preferredServiceName = s"$kubernetesResourceNamePrefix$DRIVER_SVC_POSTFIX" + val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { preferredServiceName } else { @@ -58,8 +59,8 @@ private[spark] class DriverServiceBootstrapStep( shorterServiceName } - val driverPort = submissionSparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = submissionSparkConf.getInt( + val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) + val driverBlockManagerPort = sparkConf.getInt( org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) val driverService = new ServiceBuilder() .withNewMetadata() @@ -81,7 +82,7 @@ private[spark] class DriverServiceBootstrapStep( .endSpec() .build() - val namespace = submissionSparkConf.get(KUBERNETES_NAMESPACE) + val namespace = sparkConf.get(KUBERNETES_NAMESPACE) val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" val resolvedSparkConf = 
driverSpec.driverSparkConf.clone() .set(DRIVER_HOST_KEY, driverHostname) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala new file mode 100644 index 000000000000..01469853dacc --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.KubernetesUtils + +/** + * Performs basic configuration for the driver init-container with most of the work delegated to + * the given InitContainerBootstrap. 
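+ *
+ * Only URIs with a non-container-local scheme are handed to the init-container for download;
+ * a small sketch of the classification (illustrative URIs):
+ *
+ * {{{
+ *   import org.apache.spark.deploy.k8s.KubernetesUtils
+ *
+ *   KubernetesUtils.getOnlyRemoteFiles(Seq(
+ *     "local:///opt/spark/examples/jars/examples.jar",  // already on the image - skipped
+ *     "file:///tmp/app.jar",                            // submission-client local - skipped
+ *     "https://example.com/deps/app-dep.jar",           // remote - downloaded
+ *     "hdfs://namenode:8020/deps/data.txt"))            // remote - downloaded
+ *   // == Seq("https://example.com/deps/app-dep.jar", "hdfs://namenode:8020/deps/data.txt")
+ * }}}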
+ */ +private[spark] class BasicInitContainerConfigurationStep( + sparkJars: Seq[String], + sparkFiles: Seq[String], + jarsDownloadPath: String, + filesDownloadPath: String, + bootstrap: InitContainerBootstrap) + extends InitContainerConfigurationStep { + + override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = { + val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars) + val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles) + val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) { + Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(",")) + } else { + Map() + } + val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) { + Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(",")) + } else { + Map() + } + + val baseInitContainerConfig = Map( + JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath, + FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++ + remoteJarsConf ++ + remoteFilesConf + + val bootstrapped = bootstrap.bootstrapInitContainer( + PodWithDetachedInitContainer( + spec.driverPod, + spec.initContainer, + spec.driverContainer)) + + spec.copy( + initContainer = bootstrapped.initContainer, + driverContainer = bootstrapped.mainContainer, + driverPod = bootstrapped.pod, + properties = spec.properties ++ baseInitContainerConfig) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala new file mode 100644 index 000000000000..f2c29c7ce107 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +/** + * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to + * configure the driver init-container. The returned steps will be applied in the given order to + * produce a final InitContainerSpec that is used to construct the driver init-container in + * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e., + * when there are remote application dependencies to localize. 
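+ *
+ * A construction sketch mirroring how DriverConfigOrchestrator wires it up (illustrative only;
+ * the jar URI, paths, ConfigMap coordinates and init-container image are placeholders):
+ *
+ * {{{
+ *   import org.apache.spark.SparkConf
+ *   import org.apache.spark.deploy.k8s.Config.INIT_CONTAINER_IMAGE
+ *
+ *   val sparkConf = new SparkConf(false)
+ *     .set(INIT_CONTAINER_IMAGE.key, "example.com/spark-init:latest")
+ *
+ *   val orchestrator = new InitContainerConfigOrchestrator(
+ *     sparkJars = Seq("https://example.com/deps/app-dep.jar"),
+ *     sparkFiles = Nil,
+ *     jarsDownloadPath = "/var/spark-data/spark-jars",
+ *     filesDownloadPath = "/var/spark-data/spark-files",
+ *     imagePullPolicy = "IfNotPresent",
+ *     configMapName = "example-app-init-config",
+ *     configMapKey = "spark-init.properties",
+ *     sparkConf = sparkConf)
+ *
+ *   val steps = orchestrator.getAllConfigurationSteps
+ *   // DriverInitContainerBootstrapStep then applies these steps one by one to an initially
+ *   // empty InitContainerSpec to obtain the final init-container and its properties.
+ * }}}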
+ */ +private[spark] class InitContainerConfigOrchestrator( + sparkJars: Seq[String], + sparkFiles: Seq[String], + jarsDownloadPath: String, + filesDownloadPath: String, + imagePullPolicy: String, + configMapName: String, + configMapKey: String, + sparkConf: SparkConf) { + + private val initContainerImage = sparkConf + .get(INIT_CONTAINER_IMAGE) + .getOrElse(throw new SparkException( + "Must specify the init-container image when there are remote dependencies")) + + def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = { + val initContainerBootstrap = new InitContainerBootstrap( + initContainerImage, + imagePullPolicy, + jarsDownloadPath, + filesDownloadPath, + configMapName, + configMapKey, + SPARK_POD_DRIVER_ROLE, + sparkConf) + val baseStep = new BasicInitContainerConfigurationStep( + sparkJars, + sparkFiles, + jarsDownloadPath, + filesDownloadPath, + initContainerBootstrap) + + val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_DRIVER_SECRETS_PREFIX) + // Mount user-specified driver secrets also into the driver's init-container. The + // init-container may need credentials in the secrets to be able to download remote + // dependencies. The driver's main container and its init-container share the secrets + // because the init-container is sort of an implementation details and this sharing + // avoids introducing a dedicated configuration property just for the init-container. + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + + Seq(baseStep) ++ mountSecretsStep + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala new file mode 100644 index 000000000000..0372ad527095 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +/** + * Represents a step in configuring the driver init-container. 
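+ *
+ * A sketch of a trivial custom step (illustrative; the property key and value are made up):
+ *
+ * {{{
+ *   private[spark] class ExampleExtraPropertyStep extends InitContainerConfigurationStep {
+ *     override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = {
+ *       // Everything added to `properties` ends up in the init-container ConfigMap.
+ *       spec.copy(properties = spec.properties + ("spark.example.property" -> "value"))
+ *     }
+ *   }
+ * }}}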
+ */ +private[spark] trait InitContainerConfigurationStep { + + def configureInitContainer(spec: InitContainerSpec): InitContainerSpec +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala new file mode 100644 index 000000000000..0daa7b95e8aa --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.deploy.k8s.MountSecretsBootstrap + +/** + * An init-container configuration step for mounting user-specified secrets onto user-specified + * paths. + * + * @param bootstrap a utility actually handling mounting of the secrets + */ +private[spark] class InitContainerMountSecretsStep( + bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { + + override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { + // Mount the secret volumes given that the volumes have already been added to the driver pod + // when mounting the secrets into the main driver container. + val initContainer = bootstrap.mountSecrets(spec.initContainer) + spec.copy(initContainer = initContainer) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala new file mode 100644 index 000000000000..b52c343f0c0e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod} + +/** + * Represents a specification of the init-container for the driver pod. + * + * @param properties properties that should be set on the init-container + * @param driverSparkConf Spark configuration properties that will be carried back to the driver + * @param initContainer the init-container object + * @param driverContainer the driver container object + * @param driverPod the driver pod object + * @param dependentResources resources the init-container depends on to work + */ +private[spark] case class InitContainerSpec( + properties: Map[String, String], + driverSparkConf: Map[String, String], + initContainer: Container, + driverContainer: Container, + driverPod: Pod, + dependentResources: Seq[HasMetadata]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 70226157dd68..066d7e9f70ca 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -21,35 +21,35 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} import org.apache.spark.util.Utils /** - * A factory class for configuring and creating executor pods. + * A factory class for bootstrapping and creating executor pods with the given bootstrapping + * components. + * + * @param sparkConf Spark configuration + * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto + * user-specified paths into the executor container + * @param initContainerBootstrap an optional component for bootstrapping the executor init-container + * if one is needed, i.e., when there are remote dependencies to + * localize + * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified + * secrets onto user-specified paths into the executor + * init-container */ -private[spark] trait ExecutorPodFactory { - - /** - * Configure and construct an executor pod with the given parameters. 
- */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod -} - -private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) - extends ExecutorPodFactory { +private[spark] class ExecutorPodFactory( + sparkConf: SparkConf, + mountSecretsBootstrap: Option[MountSecretsBootstrap], + initContainerBootstrap: Option[InitContainerBootstrap], + initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) { private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - private val executorLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( + private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) require( @@ -64,11 +64,11 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") private val executorAnnotations = - ConfigurationUtils.parsePrefixedKeyValuePairs( + KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) private val nodeSelector = - ConfigurationUtils.parsePrefixedKeyValuePairs( + KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) @@ -94,7 +94,10 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - override def createExecutorPod( + /** + * Configure and construct an executor pod with the given parameters. + */ + def createExecutorPod( executorId: String, applicationId: String, driverUrl: String, @@ -198,7 +201,7 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) .endSpec() .build() - val containerWithExecutorLimitCores = executorLimitCores.map { limitCores => + val containerWithLimitCores = executorLimitCores.map { limitCores => val executorCpuLimitQuantity = new QuantityBuilder(false) .withAmount(limitCores) .build() @@ -209,9 +212,35 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) .build() }.getOrElse(executorContainer) - new PodBuilder(executorPod) + val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = + mountSecretsBootstrap.map { bootstrap => + (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) + }.getOrElse((executorPod, containerWithLimitCores)) + + val (bootstrappedPod, bootstrappedContainer) = + initContainerBootstrap.map { bootstrap => + val podWithInitContainer = bootstrap.bootstrapInitContainer( + PodWithDetachedInitContainer( + maybeSecretsMountedPod, + new ContainerBuilder().build(), + maybeSecretsMountedContainer)) + + val (pod, mayBeSecretsMountedInitContainer) = + initContainerMountSecretsBootstrap.map { bootstrap => + // Mount the secret volumes given that the volumes have already been added to the + // executor pod when mounting the secrets into the main executor container. 
+ (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) + }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) + + val bootstrappedPod = KubernetesUtils.appendInitContainer( + pod, mayBeSecretsMountedInitContainer) + + (bootstrappedPod, podWithInitContainer.mainContainer) + }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer)) + + new PodBuilder(bootstrappedPod) .editSpec() - .addToContainers(containerWithExecutorLimitCores) + .addToContainers(bootstrappedContainer) .endSpec() .build() } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index b8bb152d1791..a942db6ae02d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,9 +21,9 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.util.ThreadUtils @@ -45,6 +45,59 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { val sparkConf = sc.getConf + val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME) + val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF) + + if (initContainerConfigMap.isEmpty) { + logWarning("The executor's init-container config map is not specified. Executors will " + + "therefore not attempt to fetch remote or submitted dependencies.") + } + + if (initContainerConfigMapKey.isEmpty) { + logWarning("The executor's init-container config map key is not specified. Executors will " + + "therefore not attempt to fetch remote or submitted dependencies.") + } + + // Only set up the bootstrap if they've provided both the config map key and the config map + // name. The config map might not be provided if init-containers aren't being used to + // bootstrap dependencies. 
+ val initContainerBootstrap = for { + configMap <- initContainerConfigMap + configMapKey <- initContainerConfigMapKey + } yield { + val initContainerImage = sparkConf + .get(INIT_CONTAINER_IMAGE) + .getOrElse(throw new SparkException( + "Must specify the init-container image when there are remote dependencies")) + new InitContainerBootstrap( + initContainerImage, + sparkConf.get(CONTAINER_IMAGE_PULL_POLICY), + sparkConf.get(JARS_DOWNLOAD_LOCATION), + sparkConf.get(FILES_DOWNLOAD_LOCATION), + configMap, + configMapKey, + SPARK_POD_EXECUTOR_ROLE, + sparkConf) + } + + val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { + Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) + } else { + None + } + // Mount user-specified executor secrets also into the executor's init-container. The + // init-container may need credentials in the secrets to be able to download remote + // dependencies. The executor's main container and its init-container share the secrets + // because the init-container is an implementation detail and this sharing + // avoids introducing a dedicated configuration property just for the init-container. + val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty && + executorSecretNamesToMountPaths.nonEmpty) { + Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) + } else { + None + } val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, @@ -54,7 +107,12 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf) + val executorPodFactory = new ExecutorPodFactory( + sparkConf, + mountSecretBootstrap, + initContainerBootstrap, + initContainerMountSecretsBootstrap) + val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala similarity index 54% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala index 01717479fddd..16780584a674 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala @@ -14,28 +14,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.deploy.k8s -import org.apache.spark.SparkConf +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, Pod} -private[spark] object ConfigurationUtils { - - /** - * Extract and parse Spark configuration properties with a given name prefix and - * return the result as a Map. Keys must not have more than one value.
- * - * @param sparkConf Spark configuration - * @param prefix the given property name prefix - * @return a Map storing the configuration property keys and values - */ - def parsePrefixedKeyValuePairs( - sparkConf: SparkConf, - prefix: String): Map[String, String] = { - sparkConf.getAllWithPrefix(prefix).toMap + def podHasVolume(pod: Pod, volumeName: String): Boolean = { + pod.getSpec.getVolumes.asScala.exists { volume => + volume.getName == volumeName + } } - def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { - opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + def containerHasVolume(container: Container, volumeName: String, mountPath: String): Boolean = { + container.getVolumeMounts.asScala.exists { volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath + } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala new file mode 100644 index 000000000000..e0f29ecd0fb5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s + +import java.io.File +import java.util.UUID + +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter +import org.scalatest.mockito.MockitoSugar._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.util.Utils + +class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter { + + private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt") + private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt") + + private var downloadJarsDir: File = _ + private var downloadFilesDir: File = _ + private var downloadJarsSecretValue: String = _ + private var downloadFilesSecretValue: String = _ + private var fileFetcher: FileFetcher = _ + + override def beforeAll(): Unit = { + downloadJarsSecretValue = Files.toString( + new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8) + downloadFilesSecretValue = Files.toString( + new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8) + } + + before { + downloadJarsDir = Utils.createTempDir() + downloadFilesDir = Utils.createTempDir() + fileFetcher = mock[FileFetcher] + } + + after { + downloadJarsDir.delete() + downloadFilesDir.delete() + } + + test("Downloads from remote server should invoke the file fetcher") { + val sparkConf = getSparkConfForRemoteFileDownloads + val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher) + initContainerUnderTest.run() + Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir) + Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir) + Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir) + } + + private def getSparkConfForRemoteFileDownloads: SparkConf = { + new SparkConf(true) + .set(INIT_CONTAINER_REMOTE_JARS, + "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar") + .set(INIT_CONTAINER_REMOTE_FILES, + "http://localhost:9000/file.txt") + .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath) + .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath) + } + + private def createTempFile(extension: String): String = { + val dir = Utils.createTempDir() + val file = new File(dir, s"${UUID.randomUUID().toString}.$extension") + Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8) + file.getAbsolutePath + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala similarity index 51% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 98f9f27da5cd..f193b1f4d366 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -17,25 +17,27 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config.DRIVER_CONTAINER_IMAGE +import 
org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.submit.steps._ -class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { +class DriverConfigOrchestratorSuite extends SparkFunSuite { - private val NAMESPACE = "default" private val DRIVER_IMAGE = "driver-image" + private val IC_IMAGE = "init-container-image" private val APP_ID = "spark-app-id" private val LAUNCH_TIME = 975256L private val APP_NAME = "spark" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" private val APP_ARGS = Array("arg1", "arg2") + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("Base submission steps with a main app resource.") { val sparkConf = new SparkConf(false) .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigurationStepsOrchestrator( - NAMESPACE, + val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, Some(mainAppResource), @@ -45,7 +47,7 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { sparkConf) validateStepTypes( orchestrator, - classOf[BaseDriverConfigurationStep], + classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep], classOf[DependencyResolutionStep] @@ -55,8 +57,7 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { test("Base submission steps without a main app resource.") { val sparkConf = new SparkConf(false) .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigurationStepsOrchestrator( - NAMESPACE, + val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, Option.empty, @@ -66,16 +67,62 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { sparkConf) validateStepTypes( orchestrator, - classOf[BaseDriverConfigurationStep], + classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep] ) } + test("Submission steps with an init-container.") { + val sparkConf = new SparkConf(false) + .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(INIT_CONTAINER_IMAGE, IC_IMAGE) + .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") + val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") + val orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(mainAppResource), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + validateStepTypes( + orchestrator, + classOf[BasicDriverConfigurationStep], + classOf[DriverServiceBootstrapStep], + classOf[DriverKubernetesCredentialsStep], + classOf[DependencyResolutionStep], + classOf[DriverInitContainerBootstrapStep]) + } + + test("Submission steps with driver secrets to mount") { + val sparkConf = new SparkConf(false) + .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) + val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") + val orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(mainAppResource), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + validateStepTypes( + orchestrator, + classOf[BasicDriverConfigurationStep], + classOf[DriverServiceBootstrapStep], + classOf[DriverKubernetesCredentialsStep], + classOf[DependencyResolutionStep], + 
classOf[DriverMountSecretsStep]) + } + private def validateStepTypes( - orchestrator: DriverConfigurationStepsOrchestrator, + orchestrator: DriverConfigOrchestrator, types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps() + val steps = orchestrator.getAllConfigurationSteps assert(steps.size === types.size) assert(steps.map(_.getClass) === types) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala similarity index 95% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index f7c1b3142cf7..8ee629ac8ddc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -class BaseDriverConfigurationStepSuite extends SparkFunSuite { +class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val APP_ID = "spark-app-id" private val RESOURCE_NAME_PREFIX = "spark" @@ -33,7 +33,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "arg 3") + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" @@ -52,7 +52,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - val submissionStep = new BaseDriverConfigurationStep( + val submissionStep = new BasicDriverConfigurationStep( APP_ID, RESOURCE_NAME_PREFIX, DRIVER_LABELS, @@ -82,7 +82,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" \"arg 3\"") + assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala new file mode 100644 index 000000000000..758871e2ba35 --- /dev/null +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.StringReader +import java.util.Properties + +import scala.collection.JavaConverters._ + +import com.google.common.collect.Maps +import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} +import org.apache.spark.util.Utils + +class DriverInitContainerBootstrapStepSuite extends SparkFunSuite { + + private val CONFIG_MAP_NAME = "spark-init-config-map" + private val CONFIG_MAP_KEY = "spark-init-config-map-key" + + test("The init container bootstrap step should use all of the init container steps") { + val baseDriverSpec = KubernetesDriverSpec( + driverPod = new PodBuilder().build(), + driverContainer = new ContainerBuilder().build(), + driverSparkConf = new SparkConf(false), + otherKubernetesResources = Seq.empty[HasMetadata]) + val initContainerSteps = Seq( + FirstTestInitContainerConfigurationStep, + SecondTestInitContainerConfigurationStep) + val bootstrapStep = new DriverInitContainerBootstrapStep( + initContainerSteps, + CONFIG_MAP_NAME, + CONFIG_MAP_KEY) + + val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec) + + assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala === + FirstTestInitContainerConfigurationStep.additionalLabels) + val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala + assert(additionalDriverEnv.size === 1) + assert(additionalDriverEnv.head.getName === + FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey) + assert(additionalDriverEnv.head.getValue === + FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue) + + assert(preparedDriverSpec.otherKubernetesResources.size === 2) + assert(preparedDriverSpec.otherKubernetesResources.contains( + FirstTestInitContainerConfigurationStep.additionalKubernetesResource)) + assert(preparedDriverSpec.otherKubernetesResources.exists { + case configMap: ConfigMap => + val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME + val configMapData = configMap.getData.asScala + val hasCorrectNumberOfEntries = configMapData.size == 1 + val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY) + val initContainerProperties = new Properties() + Utils.tryWithResource(new 
StringReader(initContainerPropertiesRaw)) { + initContainerProperties.load(_) + } + val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala + val expectedInitContainerProperties = Map( + SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey -> + SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue) + val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties + hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties + + case _ => false + }) + + val initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers + assert(initContainers.size() === 1) + val initContainerEnv = initContainers.get(0).getEnv.asScala + assert(initContainerEnv.size === 1) + assert(initContainerEnv.head.getName === + SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey) + assert(initContainerEnv.head.getValue === + SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue) + + val expectedSparkConf = Map( + INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY, + SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey -> + SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue) + assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) + } +} + +private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep { + + val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue") + val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY" + val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE" + val additionalKubernetesResource = new SecretBuilder() + .withNewMetadata() + .withName("test-secret") + .endMetadata() + .addToData("secret-key", "secret-value") + .build() + + override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { + val driverPod = new PodBuilder(initContainerSpec.driverPod) + .editOrNewMetadata() + .addToLabels(additionalLabels.asJava) + .endMetadata() + .build() + val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer) + .addNewEnv() + .withName(additionalMainContainerEnvKey) + .withValue(additionalMainContainerEnvValue) + .endEnv() + .build() + initContainerSpec.copy( + driverPod = driverPod, + driverContainer = mainContainer, + dependentResources = initContainerSpec.dependentResources ++ + Seq(additionalKubernetesResource)) + } +} + +private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep { + val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY" + val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE" + val additionalInitContainerPropertyKey = "spark.initcontainer.testkey" + val additionalInitContainerPropertyValue = "testvalue" + val additionalDriverSparkConfKey = "spark.driver.testkey" + val additionalDriverSparkConfValue = "spark.driver.testvalue" + + override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { + val initContainer = new ContainerBuilder(initContainerSpec.initContainer) + .addNewEnv() + .withName(additionalInitContainerEnvKey) + .withValue(additionalInitContainerEnvValue) + .endEnv() + .build() + val initContainerProperties = initContainerSpec.properties ++ + Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue) + val driverSparkConf = initContainerSpec.driverSparkConf ++ + Map(additionalDriverSparkConfKey 
-> additionalDriverSparkConfValue) + initContainerSpec.copy( + initContainer = initContainer, + properties = initContainerProperties, + driverSparkConf = driverSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala new file mode 100644 index 000000000000..960d0bda1d01 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec + +class DriverMountSecretsStepSuite extends SparkFunSuite { + + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/driver" + + test("mounts all given secrets") { + val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val secretNamesToMountPaths = Map( + SECRET_FOO -> SECRET_MOUNT_PATH, + SECRET_BAR -> SECRET_MOUNT_PATH) + + val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) + val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) + val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) + val driverPodWithSecretsMounted = configuredDriverSpec.driverPod + val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => + assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) + } + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => + assert(SecretVolumeUtils.containerHasVolume( + driverContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala new file mode 100644 index 000000000000..4553f9f6b1d4 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.when +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.Config._ + +class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SPARK_JARS = Seq( + "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") + private val SPARK_FILES = Seq( + "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") + private val JARS_DOWNLOAD_PATH = "/var/data/jars" + private val FILES_DOWNLOAD_PATH = "/var/data/files" + private val POD_LABEL = Map("bootstrap" -> "true") + private val INIT_CONTAINER_NAME = "init-container" + private val DRIVER_CONTAINER_NAME = "driver-container" + + @Mock + private var podAndInitContainerBootstrap : InitContainerBootstrap = _ + + before { + MockitoAnnotations.initMocks(this) + when(podAndInitContainerBootstrap.bootstrapInitContainer( + any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] { + override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = { + val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer]) + pod.copy( + pod = new PodBuilder(pod.pod) + .withNewMetadata() + .addToLabels("bootstrap", "true") + .endMetadata() + .withNewSpec().endSpec() + .build(), + initContainer = new ContainerBuilder() + .withName(INIT_CONTAINER_NAME) + .build(), + mainContainer = new ContainerBuilder() + .withName(DRIVER_CONTAINER_NAME) + .build() + )}}) + } + + test("additionalDriverSparkConf with mix of remote files and jars") { + val baseInitStep = new BasicInitContainerConfigurationStep( + SPARK_JARS, + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + podAndInitContainerBootstrap) + val expectedDriverSparkConf = Map( + JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH, + INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar", + INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt") + val initContainerSpec = InitContainerSpec( + Map.empty[String, String], + Map.empty[String, String], + new Container(), + new Container(), + new Pod, + Seq.empty[HasMetadata]) + val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec) + assert(expectedDriverSparkConf === returnContainerSpec.properties) + 
assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME) + assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) + assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala new file mode 100644 index 000000000000..20f2e5bc15df --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class InitContainerConfigOrchestratorSuite extends SparkFunSuite { + + private val DOCKER_IMAGE = "init-container" + private val SPARK_JARS = Seq( + "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") + private val SPARK_FILES = Seq( + "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") + private val JARS_DOWNLOAD_PATH = "/var/data/jars" + private val FILES_DOWNLOAD_PATH = "/var/data/files" + private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent" + private val CUSTOM_LABEL_KEY = "customLabel" + private val CUSTOM_LABEL_VALUE = "customLabelValue" + private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map" + private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key" + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" + + test("including basic configuration step") { + val sparkConf = new SparkConf(true) + .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) + + val orchestrator = new InitContainerConfigOrchestrator( + SPARK_JARS.take(1), + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + DOCKER_IMAGE_PULL_POLICY, + INIT_CONTAINER_CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY, + sparkConf) + val initSteps = orchestrator.getAllConfigurationSteps + assert(initSteps.lengthCompare(1) == 0) + assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) + } + + test("including step to mount user-specified secrets") { + val sparkConf = new SparkConf(false) + .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) + 
.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) + + val orchestrator = new InitContainerConfigOrchestrator( + SPARK_JARS.take(1), + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + DOCKER_IMAGE_PULL_POLICY, + INIT_CONTAINER_CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY, + sparkConf) + val initSteps = orchestrator.getAllConfigurationSteps + assert(initSteps.length === 2) + assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) + assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep]) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala new file mode 100644 index 000000000000..7ac0bde80dfe --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} + +class InitContainerMountSecretsStepSuite extends SparkFunSuite { + + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" + + test("mounts all given secrets") { + val baseInitContainerSpec = InitContainerSpec( + Map.empty, + Map.empty, + new ContainerBuilder().build(), + new ContainerBuilder().build(), + new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), + Seq.empty) + val secretNamesToMountPaths = Map( + SECRET_FOO -> SECRET_MOUNT_PATH, + SECRET_BAR -> SECRET_MOUNT_PATH) + + val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) + val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) + val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( + baseInitContainerSpec) + val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer + + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => + assert(SecretVolumeUtils.containerHasVolume( + initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 3a55d7cb37b1..884da8aabd88 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -18,15 +18,19 @@ package org.apache.spark.scheduler.cluster.k8s import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{Pod, _} -import org.mockito.MockitoAnnotations +import io.fabric8.kubernetes.api.model._ +import org.mockito.{AdditionalAnswers, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + private val driverPodName: String = "driver-pod" private val driverPodUid: String = "driver-uid" private val executorPrefix: String = "base" @@ -54,7 +58,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef } test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactoryImpl(baseConf) + val factory = new ExecutorPodFactory(baseConf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -85,7 +89,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - val factory = new ExecutorPodFactoryImpl(conf) 
+ val factory = new ExecutorPodFactory(conf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -97,7 +101,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - val factory = new ExecutorPodFactoryImpl(conf) + val factory = new ExecutorPodFactory(conf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) @@ -108,6 +112,76 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } + test("executor secrets get mounted") { + val conf = baseConf.clone() + + val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) + val factory = new ExecutorPodFactory( + conf, + Some(secretsBootstrap), + None, + None) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName + === "secret1-volume") + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) + .getMountPath === "/var/secret1") + + // check volume mounted. + assert(executor.getSpec.getVolumes.size() === 1) + assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") + + checkOwnerReferences(executor, driverPodUid) + } + + test("init-container bootstrap step adds an init container") { + val conf = baseConf.clone() + val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) + when(initContainerBootstrap.bootstrapInitContainer( + any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + + val factory = new ExecutorPodFactory( + conf, + None, + Some(initContainerBootstrap), + None) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getInitContainers.size() === 1) + checkOwnerReferences(executor, driverPodUid) + } + + test("init-container with secrets mount bootstrap") { + val conf = baseConf.clone() + val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) + when(initContainerBootstrap.bootstrapInitContainer( + any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) + + val factory = new ExecutorPodFactory( + conf, + Some(secretsBootstrap), + Some(initContainerBootstrap), + Some(secretsBootstrap)) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getVolumes.size() === 1) + assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) + assert(executor.getSpec.getInitContainers.size() === 1) + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) + + checkOwnerReferences(executor, driverPodUid) + } + // There 
is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 9b682f8673c6..45fbcd9cd0de 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -22,7 +22,7 @@ FROM spark-base # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-driver:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark-driver:latest -f kubernetes/dockerfiles/driver/Dockerfile . COPY examples /opt/spark/examples @@ -31,4 +31,5 @@ CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ if ! [ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ + if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 168cd4cb6c57..0f806cf7e148 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -22,7 +22,7 @@ FROM spark-base # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark-executor:latest -f kubernetes/dockerfiles/executor/Dockerfile . COPY examples /opt/spark/examples @@ -31,4 +31,5 @@ CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ if ! [ -z ${SPARK_MOUNTED_CLASSPATH}+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ + if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." 
.; fi && \ ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile new file mode 100644 index 000000000000..047056ab2633 --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM spark-base + +# If this docker file is being used in the context of building your images from a Spark distribution, the docker build +# command should be invoked from the top level directory of the Spark distribution. E.g.: +# docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . + +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile index 222e777db3a8..da1d6b9e161c 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -17,6 +17,9 @@ FROM openjdk:8-alpine +ARG spark_jars +ARG img_path + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. 
# If this docker file is being used in the context of building your images from a Spark @@ -34,11 +37,11 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY jars /opt/spark/jars +COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark-base/entrypoint.sh /opt/ ENV SPARK_HOME /opt/spark diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 191415a2578b..53f5f61cca48 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -92,6 +92,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private[this] var stopCalled: Boolean = false private val launcherBackend = new LauncherBackend() { + override protected def conf: SparkConf = sc.conf + override protected def onStopRequest(): Unit = { stopSchedulerBackend() setState(SparkAppHandle.State.KILLED) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2576b0d7263..4d5e3bb04367 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -427,11 +427,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() - val historyAddress = - _sparkConf.get(HISTORY_SERVER_ADDRESS) - .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } - .getOrElse("") + val historyAddress = ApplicationMaster + .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) val driverUrl = RpcEndpointAddress( _sparkConf.get("spark.driver.host"), @@ -834,6 +831,16 @@ object ApplicationMaster extends Logging { master.getAttemptId } + private[spark] def getHistoryServerAddress( + sparkConf: SparkConf, + yarnConf: YarnConfiguration, + appId: String, + attemptId: String): String = { + sparkConf.get(HISTORY_SERVER_ADDRESS) + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } + .getOrElse("") + } } /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 3781b261a038..15328d08b3b5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -100,6 +100,8 @@ private[spark] class Client( private var amKeytabFileName: String = null private val launcherBackend = new LauncherBackend() { + override protected def conf: SparkConf = sparkConf + 
override def onStopRequest(): Unit = { if (isClusterMode && appId != null) { yarnClient.killApplication(appId) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala new file mode 100644 index 000000000000..695a82f3583e --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class ApplicationMasterSuite extends SparkFunSuite { + + test("history url with hadoop and spark substitutions") { + val host = "rm.host.com" + val port = 18080 + val sparkConf = new SparkConf() + + sparkConf.set("spark.yarn.historyServer.address", + "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}") + val yarnConf = new YarnConfiguration() + yarnConf.set("yarn.resourcemanager.hostname", host) + val appId = "application_123_1" + val attemptId = appId + "_1" + + val shsAddr = ApplicationMaster + .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId) + + assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}") + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index ab0005d7b53a..061f653b97b7 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -95,7 +95,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", "spark.executor.instances" -> "2", - // Sending some senstive information, which we'll make sure gets redacted + // Sending some sensitive information, which we'll make sure gets redacted "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index 4546e98dc207..b9532597419a 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -19,28 +19,94 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. 
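# A typical invocation, mirroring the usage example shown elsewhere in this script
# (the repository and tag values below are illustrative only):
#
#   ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 build
#   ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push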
-declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile ) +function error { + echo "$@" 1>&2 + exit 1 +} + +# Detect whether this is a git clone or a Spark distribution and adjust paths +# accordingly. +if [ -z "${SPARK_HOME}" ]; then + SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ -f "$SPARK_HOME/RELEASE" ]; then + IMG_PATH="kubernetes/dockerfiles" + SPARK_JARS="jars" +else + IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" + SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." +fi + +declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ + [spark-executor]="$IMG_PATH/executor/Dockerfile" \ + [spark-init]="$IMG_PATH/init-container/Dockerfile" ) + +function image_ref { + local image="$1" + local add_repo="${2:-1}" + if [ $add_repo = 1 ] && [ -n "$REPO" ]; then + image="$REPO/$image" + fi + if [ -n "$TAG" ]; then + image="$image:$TAG" + fi + echo "$image" +} function build { - docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . + docker build \ + --build-arg "spark_jars=$SPARK_JARS" \ + --build-arg "img_path=$IMG_PATH" \ + -t spark-base \ + -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . + docker build -t "$(image_ref $image)" -f ${path[$image]} . done } - function push { for image in "${!path[@]}"; do - docker push ${REPO}/$image:${TAG} + docker push "$(image_ref $image)" done } function usage { - echo "This script must be run from a runnable distribution of Apache Spark." - echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" - echo " ./sbin/build-push-docker-images.sh -r -t push" - echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" + cat </dev/null; then + error "Cannot find minikube." + fi + eval $(minikube docker-env) + ;; esac done -if [ -z "$REPO" ] || [ -z "$TAG" ]; then +case "${@: -1}" in + build) + build + ;; + push) + if [ -z "$REPO" ]; then + usage + exit 1 + fi + push + ;; + *) usage -else - case "${@: -1}" in - build) build;; - push) push;; - *) usage;; - esac -fi + exit 1 + ;; +esac diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6fe995f650d5..6daf01d98426 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -73,18 +73,22 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS options=tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? locationSpec? - (COMMENT comment=STRING)? - (TBLPROPERTIES tableProps=tablePropertyList)? + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT comment=STRING)? 
- (PARTITIONED BY '(' partitionColumns=colTypeList ')')? - bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? #createTableLike diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 64ab01ca5740..d18542b188f7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -294,7 +294,7 @@ public void setNullAt(int ordinal) { assertIndexIsValid(ordinal); BitSetMethods.set(baseObject, baseOffset + 8, ordinal); - /* we assume the corrresponding column was already 0 or + /* we assume the corresponding column was already 0 or will be set to 0 later by the caller side */ } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 000000000000..f0f66bae245f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated + * {@link UTF8String} at the end. 
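 * A minimal usage sketch (illustrative only, based on the append/build API defined below):
 * <pre>
 *   UTF8StringBuilder sb = new UTF8StringBuilder();
 *   sb.append(UTF8String.fromString("foo"));   // append an existing UTF8String
 *   sb.append(", ");                            // append a plain Java String
 *   UTF8String result = sb.build();             // yields "foo, "
 * </pre>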
+ */ +public class UTF8StringBuilder { + + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; + + public UTF8StringBuilder() { + // Since initial buffer size is 16 in `StringBuilder`, we set the same size here + this.buffer = new byte[16]; + } + + // Grows the buffer by at least `neededSize` + private void grow(int neededSize) { + if (neededSize > ARRAY_MAX - totalSize()) { + throw new UnsupportedOperationException( + "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + ARRAY_MAX); + } + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); + buffer = tmp; + } + } + + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + + public void append(UTF8String value) { + grow(value.numBytes()); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); + } + + public void append(String value) { + append(UTF8String.fromString(value)); + } + + public UTF8String build() { + return UTF8String.fromBytes(buffer, 0, totalSize()); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 10b237fb22b9..35b35110e491 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -53,6 +52,7 @@ object SimpleAnalyzer extends Analyzer( /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns * of analysis environment from the catalog. + * The state that is kept here is per-query. * * Note this is thread local. 
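 * A minimal sketch of the intended per-query lifecycle, mirroring the execute() override that
 * this patch adds to Analyzer below (resetting before and after the top-level analysis so that
 * nested analyses triggered during resolution share a single context):
 * {{{
 *   AnalysisContext.reset()
 *   try executeSameContext(plan) finally AnalysisContext.reset()
 * }}}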
* @@ -71,6 +71,8 @@ object AnalysisContext { } def get: AnalysisContext = value.get() + def reset(): Unit = value.remove() + private def set(context: AnalysisContext): Unit = value.set(context) def withAnalysisContext[A](database: Option[String])(f: => A): A = { @@ -96,6 +98,17 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + override def execute(plan: LogicalPlan): LogicalPlan = { + AnalysisContext.reset() + try { + executeSameContext(plan) + } finally { + AnalysisContext.reset() + } + } + + private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) + def resolver: Resolver = conf.resolver protected val fixedPoint = FixedPoint(maxIterations) @@ -151,7 +164,7 @@ class Analyzer( TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: - TypeCoercion.typeCoercionRules ++ + TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("View", Once, @@ -177,7 +190,7 @@ class Analyzer( case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> execute(substituteCTE(relation, resolved)) + resolved :+ name -> executeSameContext(substituteCTE(relation, resolved)) }) case other => other } @@ -601,7 +614,7 @@ class Analyzer( "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + "aroud this.") } - execute(child) + executeSameContext(child) } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => @@ -665,14 +678,18 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, originalRight: LogicalPlan): LogicalPlan = { - // Remove analysis barrier if any. - val right = EliminateBarriers(originalRight) + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") right.collect { + // For `AnalysisBarrier`, recursively de-duplicate its child. + case oldVersion: AnalysisBarrier + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = dedupRight(left, oldVersion.child) + (oldVersion, AnalysisBarrier(newVersion)) + // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -693,7 +710,7 @@ class Analyzer( (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) case oldVersion: Generate - if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) (oldVersion, oldVersion.copy(generatorOutput = newOutput)) @@ -710,10 +727,10 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. 
*/ - originalRight + right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { + right transformUp { case r if r == oldRelation => newRelation } transformUp { case other => other transformExpressions { @@ -723,7 +740,6 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - AnalysisBarrier(newRight) } } @@ -958,7 +974,8 @@ class Analyzer( protected[sql] def resolveExpression( expr: Expression, plan: LogicalPlan, - throws: Boolean = false) = { + throws: Boolean = false): Expression = { + if (expr.resolved) return expr // Resolve expression in one round. // If throws == false or the desired attribute doesn't exist // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. @@ -1079,100 +1096,74 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, originalChild) if !s.resolved && originalChild.resolved => - val child = EliminateBarriers(originalChild) - try { - val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) - val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away after the sort. - Project(child.output, - Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) - } else if (newOrder != order) { - s.copy(order = newOrder) - } else { - s - } - } catch { - // Attempting to resolve it might fail. When this happens, return the original plan. - // Users will see an AnalysisException for resolution failure of missing attributes - // in Sort - case ae: AnalysisException => s + case s @ Sort(order, _, child) if !s.resolved && child.resolved => + val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) + val ordering = newOrder.map(_.asInstanceOf[SortOrder]) + if (child.output == newChild.output) { + s.copy(order = ordering) + } else { + // Add missing attributes and then project them away. + val newSort = s.copy(order = ordering, child = newChild) + Project(child.output, newSort) } - case f @ Filter(cond, originalChild) if !f.resolved && originalChild.resolved => - val child = EliminateBarriers(originalChild) - try { - val newCond = resolveExpressionRecursively(cond, child) - val requiredAttrs = newCond.references.filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away. - Project(child.output, - Filter(newCond, addMissingAttr(child, missingAttrs))) - } else if (newCond != cond) { - f.copy(condition = newCond) - } else { - f - } - } catch { - // Attempting to resolve it might fail. When this happens, return the original plan. - // Users will see an AnalysisException for resolution failure of missing attributes - case ae: AnalysisException => f + case f @ Filter(cond, child) if !f.resolved && child.resolved => + val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) + if (child.output == newChild.output) { + f.copy(condition = newCond.head) + } else { + // Add missing attributes and then project them away. + val newFilter = Filter(newCond.head, newChild) + Project(child.output, newFilter) } } - /** - * Add the missing attributes into projectList of Project/Window or aggregateExpressions of - * Aggregate. 
- */ - private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { - if (missingAttrs.isEmpty) { - return AnalysisBarrier(plan) - } - plan match { - case p: Project => - val missing = missingAttrs -- p.child.outputSet - Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) - case a: Aggregate => - // all the missing attributes should be grouping expressions - // TODO: push down AggregateExpression - missingAttrs.foreach { attr => - if (!a.groupingExpressions.exists(_.semanticEquals(attr))) { - throw new AnalysisException(s"Can't add $attr to ${a.simpleString}") - } - } - val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs - a.copy(aggregateExpressions = newAggregateExpressions) - case g: Generate => - // If join is false, we will convert it to true for getting from the child the missing - // attributes that its child might have or could have. - val missing = missingAttrs -- g.child.outputSet - g.copy(join = true, child = addMissingAttr(g.child, missing)) - case d: Distinct => - throw new AnalysisException(s"Can't add $missingAttrs to $d") - case u: UnaryNode => - u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) - case other => - throw new AnalysisException(s"Can't add $missingAttrs to $other") - } - } - - /** - * Resolve the expression on a specified logical plan and it's child (recursively), until - * the expression is resolved or meet a non-unary node or Subquery. - */ - @tailrec - private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { - val resolved = resolveExpression(expr, plan) - if (resolved.resolved) { - resolved + private def resolveExprsAndAddMissingAttrs( + exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { + if (exprs.forall(_.resolved)) { + // All given expressions are resolved, no need to continue anymore. + (exprs, plan) } else { plan match { - case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => - resolveExpressionRecursively(resolved, u.child) - case other => resolved + // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via + // its child. + case barrier: AnalysisBarrier => + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) + (newExprs, AnalysisBarrier(newChild)) + + case p: Project => + val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) + val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + (newExprs, Project(p.projectList ++ missingAttrs, newChild)) + + case a @ Aggregate(groupExprs, aggExprs, child) => + val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) + val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { + // All the missing attributes are grouping expressions, valid case. + (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) + } else { + // Need to add non-grouping attributes, invalid case. 
+ (exprs, a) + } + + case g: Generate => + val maybeResolvedExprs = exprs.map(resolveExpression(_, g)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) + (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) + + // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes + // via its children. + case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => + val maybeResolvedExprs = exprs.map(resolveExpression(_, u)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) + (newExprs, u.withNewChildren(Seq(newChild))) + + // For other operators, we can't recursively resolve and add attributes via its children. + case other => + (exprs.map(resolveExpression(_, other)), other) } } } @@ -1292,7 +1283,7 @@ class Analyzer( do { // Try to resolve the subquery plan using the regular analyzer. previous = current - current = execute(current) + current = executeSameContext(current) // Use the outer references to resolve the subquery plan if it isn't resolved yet. val i = plans.iterator @@ -1404,20 +1395,18 @@ class Analyzer( */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => - apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) - case filter @ Filter(havingCondition, - aggregate @ Aggregate(grouping, originalAggExprs, child)) - if aggregate.resolved => + case Filter(cond, AnalysisBarrier(agg: Aggregate)) => + apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause try { val aggregatedCondition = Aggregate( grouping, - Alias(havingCondition, "havingCondition")() :: Nil, + Alias(cond, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) + val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator .asInstanceOf[Aggregate] @@ -1436,7 +1425,7 @@ class Analyzer( // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. case e: Expression if grouping.exists(_.semanticEquals(e)) && !ResolveGroupingAnalytics.hasGroupingFunction(e) && - !aggregate.output.exists(_.semanticEquals(e)) => + !agg.output.exists(_.semanticEquals(e)) => e match { case ne: NamedExpression => aggregateExpressions += ne @@ -1450,22 +1439,22 @@ class Analyzer( // Push the aggregate expressions into the aggregate (if any). if (aggregateExpressions.nonEmpty) { - Project(aggregate.output, + Project(agg.output, Filter(transformedAggregateFilter, - aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) + agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) } else { - filter + f } } else { - filter + f } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. 
- case ae: AnalysisException => filter + case ae: AnalysisException => f } - case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => @@ -1475,7 +1464,8 @@ class Analyzer( val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAggregate: Aggregate = + executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] @@ -1603,7 +1593,7 @@ class Analyzer( resolvedGenerator = Generate( generator, - join = projectList.size > 1, // Only join if there are other expressions in SELECT. + unrequiredChildIndex = Nil, outer = outer, qualifier = None, generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 6894aed15c16..bbcec5627bd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -608,8 +608,8 @@ trait CheckAnalysis extends PredicateHelper { // allows to have correlation under it // but must not host any outer references. // Note: - // Generator with join=false is treated as Category 4. - case g: Generate if g.join => + // Generator with requiredChildOutput.isEmpty is treated as Category 4. + case g: Generate if g.requiredChildOutput.nonEmpty => failOnInvalidOuterReference(g) // Category 4: Any other operators not in the above 3 categories diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2f306f58b7b8..e8669c4637d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -45,13 +46,15 @@ import org.apache.spark.sql.types._ */ object TypeCoercion { - val typeCoercionRules = + def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion :: WidenSetOperationTypes :: PromoteStrings :: DecimalPrecision :: BooleanEquality :: FunctionArgumentConversion :: + ConcatCoercion(conf) :: + EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: @@ -324,9 +327,11 @@ object TypeCoercion { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), right) => + case a @ BinaryArithmetic(left @ StringType(), right) + if right.dataType != CalendarIntervalType => a.makeCopy(Array(Cast(left, DoubleType), right)) - case a @ BinaryArithmetic(left, right @ StringType()) => + case a @ BinaryArithmetic(left, right @ StringType()) + if left.dataType != CalendarIntervalType => a.makeCopy(Array(left, Cast(right, DoubleType))) // For equality between string and timestamp we cast the string to a timestamp @@ -658,6 +663,56 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Concat]] children to expected ones. + * + * If `spark.sql.function.concatBinaryAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. + */ + case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + case c @ Concat(children) if conf.concatBinaryAsString || + !children.map(_.dataType).forall(_ == BinaryType) => + val newChildren = c.children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) + } + } + } + + /** + * Coerces the types of [[Elt]] children to expected ones. + * + * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. + */ + case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 3bbe41cf8f15..20216087b015 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.internal.SQLConf * view resolution, in this way, we are able to get the correct view column ordering and * omit the extra columns that we don't require); * 1.2. Else set the child output attributes to `queryOutput`. - * 2. Map the `queryQutput` to view output by index, if the corresponding attributes don't match, + * 2. Map the `queryOutput` to view output by index, if the corresponding attributes don't match, * try to up cast and alias the attribute in `queryOutput` to the attribute in the view output. * 3. Add a Project over the child, with the new output generated by the previous steps. 
* If the view output doesn't have the same number of columns neither with the child output, nor diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 7c100afcd738..59cb26d5e6c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -359,12 +359,12 @@ package object dsl { def generate( generator: Generator, - join: Boolean = false, + unrequiredChildIndex: Seq[Int] = Nil, outer: Boolean = false, alias: Option[String] = None, outputNames: Seq[String] = Nil): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, - outputNames.map(UnresolvedAttribute(_)), logicalPlan) + Generate(generator, unrequiredChildIndex, outer, + alias, outputNames.map(UnresolvedAttribute(_)), logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5279d4127896..f2de4c8e30be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -181,7 +181,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType} to $dataType") + s"cannot cast ${child.dataType.simpleString} to ${dataType.simpleString}") } } @@ -206,6 +206,59 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) + case ArrayType(et, _) => + buildCast[ArrayData](_, array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (!array.isNullAt(0)) { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (!array.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) + case MapType(kt, vt, _) => + buildCast[MapData](_, map => { + val builder = new UTF8StringBuilder + builder.append("[") + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(0)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(i)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => 
buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -597,6 +650,88 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ } + private def writeArrayToStringBuilder( + et: DataType, + array: String, + buffer: String, + ctx: CodegenContext): String = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val elementToStringFunc = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(et)} element) { + | UTF8String elementStr = null; + | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + | return elementStr; + |} + """.stripMargin) + + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if (!$array.isNullAt(0)) { + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if (!$array.isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: String, + buffer: String, + ctx: CodegenContext): String = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + | UTF8String dataStr = null; + | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + | return dataStr; + |} + """.stripMargin) + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt(0)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc( + | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -608,6 +743,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" + case ArrayType(et, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } + case MapType(kt, 
vt, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 7facb9dad9a7..a45854a3b514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -132,7 +132,7 @@ case class ApproximatePercentile( case TimestampType => value.asInstanceOf[Long].toDouble case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType]) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type $other") + throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") } buffer.add(doubleValue) } @@ -157,7 +157,7 @@ case class ApproximatePercentile( case DoubleType => doubleResult case _: DecimalType => doubleResult.map(Decimal(_)) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type $other") + throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") } if (result.length == 0) { null @@ -296,8 +296,8 @@ object ApproximatePercentile { Ints.BYTES + Doubles.BYTES + Longs.BYTES + // length of summary.sampled Ints.BYTES + - // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] - summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)] + summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES) } final def serialize(obj: PercentileDigest): Array[Byte] = { @@ -312,8 +312,8 @@ object ApproximatePercentile { while (i < summary.sampled.length) { val stat = summary.sampled(i) buffer.putDouble(stat.value) - buffer.putInt(stat.g) - buffer.putInt(stat.delta) + buffer.putLong(stat.g) + buffer.putLong(stat.delta) i += 1 } buffer.array() @@ -330,8 +330,8 @@ object ApproximatePercentile { var i = 0 while (i < sampledLength) { val value = buffer.getDouble() - val g = buffer.getInt() - val delta = buffer.getInt() + val g = buffer.getLong() + val delta = buffer.getLong() sampled(i) = Stats(value, g, delta) i += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d6eccadcfb63..2c714c228e6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -190,7 +190,7 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the config `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. 
* Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { @@ -352,7 +352,7 @@ class CodegenContext { def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - val initCodes = mutableStateInitCode.distinct + val initCodes = mutableStateInitCode.distinct.map(_ + "\n") // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 142dfb02be0a..b444c3a7be92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -40,7 +40,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def checkInputDataTypes(): TypeCheckResult = { if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( - s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + "type of predicate expression in If should be boolean, " + + s"not ${predicate.dataType.simpleString}") } else if (!trueValue.dataType.sameType(falseValue.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 69af7a250a5a..4f4d49166e88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -155,8 +155,8 @@ case class Stack(children: Seq[Expression]) extends Generator { val j = (i - 1) % numFields if (children(i).dataType != elementSchema.fields(j).dataType) { return TypeCheckResult.TypeCheckFailure( - s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " + - s"Argument $i (${children(i).dataType})") + s"Argument ${j + 1} (${elementSchema.fields(j).dataType.simpleString}) != " + + s"Argument $i (${children(i).dataType.simpleString})") } } TypeCheckResult.TypeCheckSuccess @@ -249,7 +249,8 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( - s"input to function explode should be array or map type, not ${child.dataType}") + "input to function explode should be array or map type, " + + s"not ${child.dataType.simpleString}") } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -378,7 +379,8 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should be array of struct type, not ${child.dataType}") + s"input to function $prettyName should be array of struct type, " + + s"not ${child.dataType.simpleString}") } override def elementSchema: StructType = child.dataType match { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4af813456b79..64da9bb9cdec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -51,7 +51,7 @@ trait InvokeLike extends Expression with NonSQLExpression { * * - generate codes for argument. * - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments. - * - avoid some of nullabilty checking which are not needed because the expression is not + * - avoid some of nullability checking which are not needed because the expression is not * nullable. * - when needNullCheck == true, short circuit if we found one of arguments is null because * preparing rest of arguments can be skipped in the case. @@ -193,7 +193,8 @@ case class StaticInvoke( * @param targetObject An expression that will return the object to call the method on. * @param functionName The name of the method to call. * @param dataType The expected return type of the function. - * @param arguments An optional list of expressions, whos evaluation will be passed to the function. + * @param arguments An optional list of expressions, whose evaluation will be passed to the + * function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. * @param returnNullable When false, indicating the invoked method will always return diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f4ee3d10f3f4..b469f5cb7586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -195,7 +195,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } case _ => TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + - s"${value.dataType} != ${mismatchOpt.get.dataType}") + s"${value.dataType.simpleString} != ${mismatchOpt.get.dataType.simpleString}") } } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index fa5425c77ebb..f3e8f6de5897 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -118,9 +118,8 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - // inline mutable state since not many Like operations in a task val pattern = ctx.addMutableState(patternClass, "patternLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) + v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
val eval = left.genCode(ctx) @@ -143,9 +142,9 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi val rightStr = ctx.freshName("rightStr") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr)); - ${ev.value} = $pattern.matcher(${eval1}.toString()).matches(); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile($escapeFunc($rightStr)); + ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) } @@ -194,9 +193,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - // inline mutable state since not many RLike operations in a task val pattern = ctx.addMutableState(patternClass, "patternRLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) + v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -219,9 +217,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress val pattern = ctx.freshName("pattern") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($rightStr); - ${ev.value} = $pattern.matcher(${eval1}.toString()).find(0); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile($rightStr); + ${ev.value} = $pattern.matcher($eval1.toString()).find(0); """ }) } @@ -338,25 +336,25 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if (!$regexp.equals($termLastRegex)) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + $termLastRegex = $regexp.clone(); + $termPattern = $classNamePattern.compile($termLastRegex.toString()); } - if (!$rep.equals(${termLastReplacementInUTF8})) { + if (!$rep.equals($termLastReplacementInUTF8)) { // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + $termLastReplacementInUTF8 = $rep.clone(); + $termLastReplacement = $termLastReplacementInUTF8.toString(); } - $classNameStringBuffer ${termResult} = new $classNameStringBuffer(); - java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); + $classNameStringBuffer $termResult = new $classNameStringBuffer(); + java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); - while (${matcher}.find()) { - ${matcher}.appendReplacement(${termResult}, ${termLastReplacement}); + while ($matcher.find()) { + $matcher.appendReplacement($termResult, $termLastReplacement); } - ${matcher}.appendTail(${termResult}); - ${ev.value} = UTF8String.fromString(${termResult}.toString()); - ${termResult} = null; + $matcher.appendTail($termResult); + ${ev.value} = UTF8String.fromString($termResult.toString()); + $termResult = null; $setEvNotNull """ }) @@ -425,19 +423,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if 
(!$regexp.equals($termLastRegex)) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + $termLastRegex = $regexp.clone(); + $termPattern = $classNamePattern.compile($termLastRegex.toString()); } - java.util.regex.Matcher ${matcher} = - ${termPattern}.matcher($subject.toString()); - if (${matcher}.find()) { - java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult(); - if (${matchResult}.group($idx) == null) { + java.util.regex.Matcher $matcher = + $termPattern.matcher($subject.toString()); + if ($matcher.find()) { + java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); + if ($matchResult.group($idx) == null) { ${ev.value} = UTF8String.EMPTY_UTF8; } else { - ${ev.value} = UTF8String.fromString(${matchResult}.group($idx)); + ${ev.value} = UTF8String.fromString($matchResult.group($idx)); } $setEvNotNull } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c02c41db1668..e004bfc6af47 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -24,11 +24,10 @@ import java.util.regex.Pattern import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -38,7 +37,8 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} /** - * An expression that concatenates multiple input strings into a single string. + * An expression that concatenates multiple inputs into a single output. + * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. * If any input is null, concat returns null. 
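 * An illustrative sketch of the resulting types (column names are hypothetical; assumes
 * spark.sql.function.concatBinaryAsString is false, per the ConcatCoercion rule added in this
 * patch):
 * {{{
 *   concat(binCol1, binCol2)  // all children BinaryType => result type is BinaryType
 *   concat(binCol1, strCol)   // mixed inputs => children coerced to StringType, result StringType
 * }}}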
*/ @ExpressionDescription( @@ -48,17 +48,37 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} > SELECT _FUNC_('Spark', 'SQL'); SparkSQL """) -case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { +case class Concat(children: Seq[Expression]) extends Expression { - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) - override def dataType: DataType = StringType + private lazy val isBinaryMode: Boolean = dataType == BinaryType + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + } + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) override def eval(input: InternalRow): Any = { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) + if (isBinaryMode) { + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + } else { + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -73,17 +93,27 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } + + val (concatenator, initCode) = if (isBinaryMode) { + (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + } else { + ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = ("UTF8String[]", args) :: Nil) + extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil) ev.copy(s""" - UTF8String[] $args = new UTF8String[${evals.length}]; + $initCode $codes - UTF8String ${ev.value} = UTF8String.concat($args); + ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args); boolean ${ev.isNull} = ${ev.value} == null; """) } + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } @@ -241,33 +271,45 @@ case class ConcatWs(children: Seq[Expression]) } } +/** + * An expression that returns the `n`-th input in given inputs. + * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string. + * If any input is null, `elt` returns null. + */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.", + usage = "_FUNC_(n, input1, input2, ...) 
- Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); scala """) // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { +case class Elt(children: Seq[Expression]) extends Expression { private lazy val indexExpr = children.head - private lazy val stringExprs = children.tail.toArray + private lazy val inputExprs = children.tail.toArray /** This expression is always nullable because it returns null if index is out of range. */ override def nullable: Boolean = true - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") } else { - super[ImplicitCastInputTypes].checkInputDataTypes() + val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) + if (indexType != IntegerType) { + return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + + s"have IntegerType, but it's $indexType") + } + if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } } @@ -277,27 +319,27 @@ case class Elt(children: Seq[Expression]) null } else { val index = indexObj.asInstanceOf[Int] - if (index <= 0 || index > stringExprs.length) { + if (index <= 0 || index > inputExprs.length) { null } else { - stringExprs(index - 1).eval(input) + inputExprs(index - 1).eval(input) } } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) - val strings = stringExprs.map(_.genCode(ctx)) + val inputs = inputExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") + val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") - val assignStringValue = strings.zipWithIndex.map { case (eval, index) => + val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" |if ($indexVal == ${index + 1}) { | ${eval.code} - | $stringVal = ${eval.isNull} ? null : ${eval.value}; + | $inputVal = ${eval.isNull} ? 
null : ${eval.value}; | $indexMatched = true; | continue; |} @@ -305,7 +347,7 @@ case class Elt(children: Seq[Expression]) } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignStringValue, + expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, returnType = ctx.JAVA_BOOLEAN, @@ -331,11 +373,11 @@ case class Elt(children: Seq[Expression]) |${index.code} |final int $indexVal = ${index.value}; |${ctx.JAVA_BOOLEAN} $indexMatched = false; - |$stringVal = null; + |$inputVal = null; |do { | $codes |} while (false); - |final UTF8String ${ev.value} = $stringVal; + |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 220cc4f885d7..dd13d9a3bba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -70,9 +70,9 @@ case class WindowSpecDefinition( case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && !isValidFrameType(f.valueBoundary.head.dataType) => TypeCheckFailure( - s"The data type '${orderSpec.head.dataType}' used in the order specification does " + - s"not match the data type '${f.valueBoundary.head.dataType}' which is used in the " + - "range frame.") + s"The data type '${orderSpec.head.dataType.simpleString}' used in the order " + + "specification does not match the data type " + + s"'${f.valueBoundary.head.dataType.simpleString}' which is used in the range frame.") case _ => TypeCheckSuccess } } @@ -251,7 +251,7 @@ case class SpecifiedWindowFrame( TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") case e: Expression if !frameType.inputType.acceptsType(e.dataType) => TypeCheckFailure( - s"The data type of the $location bound '${e.dataType}' does not match " + + s"The data type of the $location bound '${e.dataType.simpleString}' does not match " + s"the expected data type '${frameType.inputType.simpleString}'.") case _ => TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6a4d1e997c3c..0d4b02c6e7d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -456,12 +456,15 @@ object ColumnPruning extends Rule[LogicalPlan] { f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) - case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => - g.copy(child = prunedChild(g.child, g.references)) - // Turn off `join` for Generate if no column from it's child is used - case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => - p.copy(child = g.copy(join = false)) + // prune unrequired references + case p @ Project(_, g: Generate) if p.references != g.outputSet => + val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references + val newChild = prunedChild(g.child, requiredAttrs) + val 
unrequired = g.generator.references -- p.references + val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1)) + .map(_._2) + p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) // Eliminate unneeded attributes from right side of a Left Existence Join. case j @ Join(_, right, LeftExistence(_), _) => @@ -802,15 +805,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // For each filter, expand the alias and check if the filter can be evaluated using // attributes produced by the aggregate operator's child operator. - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => val replaced = replaceAlias(cond, aliasMap) cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) @@ -832,14 +835,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => cond.references.subsetOf(partitionAttrs) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) @@ -851,7 +854,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down - val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic) if (pushDown.nonEmpty) { val pushDownCond = pushDown.reduceLeft(And) @@ -875,13 +878,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } case filter @ Filter(condition, watermark: EventTimeWatermark) => - // We can only push deterministic predicates which don't reference the watermark attribute. - // We could in theory span() only on determinism and pull out deterministic predicates - // on the watermark separately. But it seems unnecessary and a bit confusing to not simply - // use the prefix as we do for nondeterminism in other cases. - - val (pushDown, stayUp) = splitConjunctivePredicates(condition).span( - p => p.deterministic && !p.references.contains(watermark.eventTime)) + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { p => + p.deterministic && !p.references.contains(watermark.eventTime) + } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduceLeft(And) @@ -922,14 +921,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // come from grandchild. // TODO: non-deterministic predicates could be pushed through some operators that do not change // the rows. 
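The span-to-partition switch in these PushDownPredicate and PushPredicateThroughJoin hunks changes which deterministic conjuncts are eligible for push-down: span only keeps the prefix before the first non-deterministic predicate, while partition collects every deterministic predicate regardless of position. A standalone Scala sketch (Pred is a made-up stand-in here, not Spark's Expression API):

case class Pred(sql: String, deterministic: Boolean)

val conjuncts = Seq(
  Pred("a > 1", deterministic = true),
  Pred("rand() < 0.5", deterministic = false),
  Pred("b = 2", deterministic = true))

// Old behaviour: only the deterministic prefix is considered for push-down.
assert(conjuncts.span(_.deterministic)._1.map(_.sql) == Seq("a > 1"))
// New behaviour: every deterministic conjunct is a push-down candidate.
assert(conjuncts.partition(_.deterministic)._1.map(_.sql) == Seq("a > 1", "b = 2"))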
- val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(filter.condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => cond.references.subsetOf(grandchild.outputSet) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val newChild = insertFilter(pushDown.reduceLeft(And)) @@ -972,23 +971,19 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions or filter predicates (on a given join's output) into three * categories based on the attributes required to evaluate them. Note that we explicitly exclude - * on-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or + * non-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or * canEvaluateInRight to prevent pushing these predicates on either side of the join. * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { - // Note: In order to ensure correctness, it's important to not change the relative ordering of - // any deterministic expression that follows a non-deterministic expression. To achieve this, - // we only consider pushing down those expressions that precede the first non-deterministic - // expression in the condition. - val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic) + val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic) val (leftEvaluateCondition, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, commonCondition) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic) + (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 85295aff1980..1c0b7bd80680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -505,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] { /** - * Propagate foldable expressions: * Replace attributes with aliases of the original foldable expressions if possible. - * Other optimizations will take advantage of the propagated foldable expressions. - * + * Other optimizations will take advantage of the propagated foldable expressions. 
For example, + * this rule can optimize * {{{ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 - * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() * }}} + * to + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + * and other rules can further optimize it and remove the ORDER BY operator. */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val foldableMap = AttributeMap(plan.flatMap { + var foldableMap = AttributeMap(plan.flatMap { case Project(projectList, _) => projectList.collect { case a: Alias if a.child.foldable => (a.toAttribute, a) } @@ -529,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] { if (foldableMap.isEmpty) { plan } else { - var stop = false CleanupAliases(plan.transformUp { - // A leaf node should not stop the folding process (note that we are traversing up the - // tree, starting at the leaf nodes); so we are allowing it. - case l: LeafNode => - l - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if !stop && canPropagateFoldables(u) => + case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => u.transformExpressions(replaceFoldable) - // Allow inner joins. We do not allow outer join, although its output attributes are - // derived from its children, they are actually different attributes: the output of outer - // join is not always picked from its children, but can also be null. + // Join derives the output attributes from its child while they are actually not the + // same attributes. For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) if !stop => - j.transformExpressions(replaceFoldable) - - // We can fold the projections an expand holds. However expand changes the output columns - // and often reuses the underlying attributes; so we cannot assume that a column is still - // foldable after the expand has been applied. - // TODO(hvanhovell): Expand should use new attributes as the output attributes. - case expand: Expand if !stop => - val newExpand = expand.copy(projections = expand.projections.map { projection => + case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + val newJoin = j.transformExpressions(replaceFoldable) + val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => right.output + case RightOuter => left.output + case FullOuter => left.output ++ right.output + }) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + newJoin + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case expand: Expand if foldableMap.nonEmpty => + expand.copy(projections = expand.projections.map { projection => projection.map(_.transform(replaceFoldable)) }) - stop = true - newExpand - case other => - stop = true + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. 
+ case other if foldableMap.nonEmpty => + val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => childrenOutputSet.contains(attr) + }.toSeq) other }) } @@ -645,6 +655,12 @@ object CombineConcats extends Rule[LogicalPlan] { stack.pop() match { case Concat(children) => stack.pushAll(children.reverse) + // If `spark.sql.function.concatBinaryAsString` is false, nested `Concat` exprs possibly + // have `Concat`s with binary output. Since `TypeCoercion` casts them into strings, + // we need to handle the case to combine all nested `Concat`s. + case c @ Cast(Concat(children), StringType, _) => + val newChildren = children.map { e => c.copy(child = e) } + stack.pushAll(newChildren.reverse) case child => flattened += child } @@ -652,8 +668,14 @@ object CombineConcats extends Rule[LogicalPlan] { Concat(flattened) } + private def hasNestedConcats(concat: Concat): Boolean = concat.children.exists { + case c: Concat => true + case c @ Cast(Concat(children), StringType, _) => true + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { - case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) => + case concat: Concat if hasNestedConcats(concat) => flattenConcats(concat) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7651d11ee65a..bdc357d54a87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -623,7 +623,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val expressions = expressionList(ctx.expression) Generate( UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions), - join = true, + unrequiredChildIndex = Nil, outer = ctx.OUTER != null, Some(ctx.tblName.getText.toLowerCase), ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9b127f91648e..89347f4b1f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import java.util + import scala.collection.mutable.StringBuilder import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -39,6 +41,13 @@ object ParserUtils { throw new ParseException(s"Operation not allowed: $message", ctx) } + def checkDuplicateClauses[T]( + nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = { + if (nodes.size() > 1) { + throw new ParseException(s"Found duplicate clauses: $clauseName", ctx) + } + } + /** Check if duplicate keys exist in a set of key-value pairs. 
*/ def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a38458add7b5..ff2a0ec58856 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -247,6 +247,8 @@ abstract class UnaryNode extends LogicalPlan { protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { + case a @ Alias(l: Literal, _) => + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b0f611fd38de..9c0a30a47f83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -98,7 +98,7 @@ trait QueryPlanConstraints { self: LogicalPlan => // we may avoid producing recursive constraints. private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { - case a: Alias => (a.toAttribute, a.child) + case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index cd474551622d..95e099c340af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -73,8 +73,13 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * their output. * * @param generator the generator expression - * @param join when true, each output row is implicitly joined with the input tuple that produced - * it. + * @param unrequiredChildIndex this paramter starts as Nil and gets filled by the Optimizer. + * It's used as an optimization for omitting data generation that will + * be discarded next by a projection. + * A common use case is when we explode(array(..)) and are interested + * only in the exploded data and not in the original array. before this + * optimization the array got duplicated for each of its elements, + * causing O(n^^2) memory consumption. (see [SPARK-21657]) * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. 
* @param qualifier Qualifier for the attributes of generator(UDTF) @@ -83,15 +88,17 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend */ case class Generate( generator: Generator, - join: Boolean, + unrequiredChildIndex: Seq[Int], outer: Boolean, qualifier: Option[String], generatorOutput: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - /** The set of all attributes produced by this node. */ - def generatedSet: AttributeSet = AttributeSet(generatorOutput) + lazy val requiredChildOutput: Seq[Attribute] = { + val unrequiredSet = unrequiredChildIndex.toSet + child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1) + } override lazy val resolved: Boolean = { generator.resolved && @@ -114,9 +121,7 @@ case class Generate( nullableOutput } - def output: Seq[Attribute] = { - if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput - } + def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput } case class Filter(condition: Expression, child: LogicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 71e852afe065..d793f77413d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -89,29 +89,29 @@ object EstimationUtils { } /** - * For simplicity we use Decimal to unify operations for data types whose min/max values can be + * For simplicity we use Double to unify operations for data types whose min/max values can be * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true). * The two methods below are the contract of conversion. 
*/ - def toDecimal(value: Any, dataType: DataType): Decimal = { + def toDouble(value: Any, dataType: DataType): Double = { dataType match { - case _: NumericType | DateType | TimestampType => Decimal(value.toString) - case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0) + case _: NumericType | DateType | TimestampType => value.toString.toDouble + case BooleanType => if (value.asInstanceOf[Boolean]) 1 else 0 } } - def fromDecimal(dec: Decimal, dataType: DataType): Any = { + def fromDouble(double: Double, dataType: DataType): Any = { dataType match { - case BooleanType => dec.toLong == 1 - case DateType => dec.toInt - case TimestampType => dec.toLong - case ByteType => dec.toByte - case ShortType => dec.toShort - case IntegerType => dec.toInt - case LongType => dec.toLong - case FloatType => dec.toFloat - case DoubleType => dec.toDouble - case _: DecimalType => dec + case BooleanType => double.toInt == 1 + case DateType => double.toInt + case TimestampType => double.toLong + case ByteType => double.toByte + case ShortType => double.toShort + case IntegerType => double.toInt + case LongType => double.toLong + case FloatType => double.toFloat + case DoubleType => double + case _: DecimalType => Decimal(double) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 850dd1ba724a..4cc32de2d32d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -31,7 +31,7 @@ case class FilterEstimation(plan: Filter) extends Logging { private val childStats = plan.child.stats - private val colStatsMap = new ColumnStatsMap(childStats.attributeStats) + private val colStatsMap = ColumnStatsMap(childStats.attributeStats) /** * Returns an option of Statistics for a Filter logical plan node. @@ -47,7 +47,7 @@ case class FilterEstimation(plan: Filter) extends Logging { // Estimate selectivity of this filter predicate, and update column stats if needed. // For not-supported condition, set filter selectivity to a conservative estimate 100% - val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(BigDecimal(1)) + val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(1.0) val filteredRowCount: BigInt = ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) val newColStats = if (filteredRowCount == 0) { @@ -79,17 +79,16 @@ case class FilterEstimation(plan: Filter) extends Logging { * @return an optional double value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. 
*/ - def calculateFilterSelectivity(condition: Expression, update: Boolean = true) - : Option[BigDecimal] = { + def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(BigDecimal(1)) - val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(BigDecimal(1)) + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) Some(percent1 * percent2) case Or(cond1, cond2) => - val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(BigDecimal(1)) - val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(BigDecimal(1)) + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) Some(percent1 + percent2 - (percent1 * percent2)) // Not-operator pushdown @@ -131,7 +130,7 @@ case class FilterEstimation(plan: Filter) extends Logging { * @return an optional double value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. */ - def calculateSingleCondition(condition: Expression, update: Boolean): Option[BigDecimal] = { + def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { condition match { case l: Literal => evaluateLiteral(l) @@ -225,17 +224,17 @@ case class FilterEstimation(plan: Filter) extends Logging { def evaluateNullCheck( attr: Attribute, isNull: Boolean, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) val rowCountValue = childStats.rowCount.get - val nullPercent: BigDecimal = if (rowCountValue == 0) { + val nullPercent: Double = if (rowCountValue == 0) { 0 } else { - BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue) + (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble } if (update) { @@ -271,7 +270,7 @@ case class FilterEstimation(plan: Filter) extends Logging { op: BinaryComparison, attr: Attribute, literal: Literal, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None @@ -305,13 +304,12 @@ case class FilterEstimation(plan: Filter) extends Logging { def evaluateEquality( attr: Attribute, literal: Literal, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) - val ndv = colStat.distinctCount // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. 
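The Double-based combination in calculateFilterSelectivity above keeps the usual independence assumption for And/Or; a quick numeric sanity check with made-up selectivities:

// Hypothetical per-predicate selectivities, assumed independent.
val p1 = 0.2
val p2 = 0.5
val andSelectivity = p1 * p2             // conjunction: 0.1
val orSelectivity = p1 + p2 - p1 * p2    // disjunction, by inclusion-exclusion: 0.6
assert(math.abs(andSelectivity - 0.1) < 1e-9)
assert(math.abs(orSelectivity - 0.6) < 1e-9)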
@@ -334,7 +332,7 @@ case class FilterEstimation(plan: Filter) extends Logging { if (colStat.histogram.isEmpty) { // returns 1/ndv if there is no histogram - Some(1.0 / BigDecimal(ndv)) + Some(1.0 / colStat.distinctCount.toDouble) } else { Some(computeEqualityPossibilityByHistogram(literal, colStat)) } @@ -354,7 +352,7 @@ case class FilterEstimation(plan: Filter) extends Logging { * @param literal a literal value (or constant) * @return an optional double value to show the percentage of rows meeting a given condition */ - def evaluateLiteral(literal: Literal): Option[BigDecimal] = { + def evaluateLiteral(literal: Literal): Option[Double] = { literal match { case Literal(null, _) => Some(0.0) case FalseLiteral => Some(0.0) @@ -379,7 +377,7 @@ case class FilterEstimation(plan: Filter) extends Logging { def evaluateInSet( attr: Attribute, hSet: Set[Any], - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None @@ -403,8 +401,8 @@ case class FilterEstimation(plan: Filter) extends Logging { return Some(0.0) } - val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType)) - val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType)) + val newMax = validQuerySet.maxBy(EstimationUtils.toDouble(_, dataType)) + val newMin = validQuerySet.minBy(EstimationUtils.toDouble(_, dataType)) // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) @@ -425,7 +423,7 @@ case class FilterEstimation(plan: Filter) extends Logging { // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. 
- Some((BigDecimal(newNdv) / BigDecimal(ndv)).min(1.0)) + Some(math.min(newNdv.toDouble / ndv.toDouble, 1.0)) } /** @@ -443,21 +441,17 @@ case class FilterEstimation(plan: Filter) extends Logging { op: BinaryComparison, attr: Attribute, literal: Literal, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { val colStat = colStatsMap(attr) val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] - val max = statsInterval.max.toBigDecimal - val min = statsInterval.min.toBigDecimal - val ndv = BigDecimal(colStat.distinctCount) + val max = statsInterval.max + val min = statsInterval.min + val ndv = colStat.distinctCount.toDouble // determine the overlapping degree between predicate interval and column's interval - val numericLiteral = if (literal.dataType == BooleanType) { - if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) - } else { - BigDecimal(literal.value.toString) - } + val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { case _: LessThan => (numericLiteral <= min, numericLiteral > max) @@ -469,7 +463,7 @@ case class FilterEstimation(plan: Filter) extends Logging { (numericLiteral > max, numericLiteral <= min) } - var percent = BigDecimal(1) + var percent = 1.0 if (noOverlap) { percent = 0.0 } else if (completeOverlap) { @@ -518,8 +512,6 @@ case class FilterEstimation(plan: Filter) extends Logging { val newValue = Some(literal.value) var newMax = colStat.max var newMin = colStat.min - var newNdv = ceil(ndv * percent) - if (newNdv < 1) newNdv = 1 op match { case _: GreaterThan | _: GreaterThanOrEqual => @@ -528,8 +520,8 @@ case class FilterEstimation(plan: Filter) extends Logging { newMax = newValue } - val newStats = - colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = ceil(ndv * percent), + min = newMin, max = newMax, nullCount = 0) colStatsMap.update(attr, newStats) } @@ -543,13 +535,13 @@ case class FilterEstimation(plan: Filter) extends Logging { */ private def computeEqualityPossibilityByHistogram( literal: Literal, colStat: ColumnStat): Double = { - val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble + val datum = EstimationUtils.toDouble(literal.value, literal.dataType) val histogram = colStat.histogram.get // find bins where column's current min and max locate. Note that a column's [min, max] // range may change due to another condition applied earlier. - val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble - val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble + val min = EstimationUtils.toDouble(colStat.min.get, literal.dataType) + val max = EstimationUtils.toDouble(colStat.max.get, literal.dataType) // compute how many bins the column's current valid range [min, max] occupies. val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange( @@ -574,13 +566,13 @@ case class FilterEstimation(plan: Filter) extends Logging { */ private def computeComparisonPossibilityByHistogram( op: BinaryComparison, literal: Literal, colStat: ColumnStat): Double = { - val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble + val datum = EstimationUtils.toDouble(literal.value, literal.dataType) val histogram = colStat.histogram.get // find bins where column's current min and max locate. 
Note that a column's [min, max] // range may change due to another condition applied earlier. - val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble - val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble + val min = EstimationUtils.toDouble(colStat.min.get, literal.dataType) + val max = EstimationUtils.toDouble(colStat.max.get, literal.dataType) // compute how many bins the column's current valid range [min, max] occupies. val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange( @@ -643,7 +635,7 @@ case class FilterEstimation(plan: Filter) extends Logging { op: BinaryComparison, attrLeft: Attribute, attrRight: Attribute, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attrLeft)) { logDebug("[CBO] No statistics for " + attrLeft) @@ -726,7 +718,7 @@ case class FilterEstimation(plan: Filter) extends Logging { ) } - var percent = BigDecimal(1) + var percent = 1.0 if (noOverlap) { percent = 0.0 } else if (completeOverlap) { @@ -740,11 +732,9 @@ case class FilterEstimation(plan: Filter) extends Logging { // Need to adjust new min/max after the filter condition is applied val ndvLeft = BigDecimal(colStatLeft.distinctCount) - var newNdvLeft = ceil(ndvLeft * percent) - if (newNdvLeft < 1) newNdvLeft = 1 + val newNdvLeft = ceil(ndvLeft * percent) val ndvRight = BigDecimal(colStatRight.distinctCount) - var newNdvRight = ceil(ndvRight * percent) - if (newNdvRight < 1) newNdvRight = 1 + val newNdvRight = ceil(ndvRight * percent) var newMaxLeft = colStatLeft.max var newMinLeft = colStatLeft.min diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala index 0caaf796a3b6..f46b4ed764e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala @@ -26,10 +26,10 @@ trait ValueInterval { def contains(l: Literal): Boolean } -/** For simplicity we use decimal to unify operations of numeric intervals. */ -case class NumericValueInterval(min: Decimal, max: Decimal) extends ValueInterval { +/** For simplicity we use double to unify operations of numeric intervals. */ +case class NumericValueInterval(min: Double, max: Double) extends ValueInterval { override def contains(l: Literal): Boolean = { - val lit = EstimationUtils.toDecimal(l.value, l.dataType) + val lit = EstimationUtils.toDouble(l.value, l.dataType) min <= lit && max >= lit } } @@ -56,8 +56,8 @@ object ValueInterval { case _ if min.isEmpty || max.isEmpty => new NullValueInterval() case _ => NumericValueInterval( - min = EstimationUtils.toDecimal(min.get, dataType), - max = EstimationUtils.toDecimal(max.get, dataType)) + min = EstimationUtils.toDouble(min.get, dataType), + max = EstimationUtils.toDouble(max.get, dataType)) } def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match { @@ -84,8 +84,8 @@ object ValueInterval { // Choose the maximum of two min values, and the minimum of two max values. 
val newMin = if (n1.min <= n2.min) n2.min else n1.min val newMax = if (n1.max <= n2.max) n1.max else n2.max - (Some(EstimationUtils.fromDecimal(newMin, dt)), - Some(EstimationUtils.fromDecimal(newMax, dt))) + (Some(EstimationUtils.fromDouble(newMin, dt)), + Some(EstimationUtils.fromDouble(newMax, dt))) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index eb7941cf9e6a..b013add9c977 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -105,7 +105,7 @@ class QuantileSummaries( if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * relativeError * currentCount).toInt + math.floor(2 * relativeError * currentCount).toLong } val tuple = Stats(currentSample, 1, delta) @@ -192,10 +192,10 @@ class QuantileSummaries( } // Target rank - val rank = math.ceil(quantile * count).toInt + val rank = math.ceil(quantile * count).toLong val targetError = relativeError * count // Minimum rank at current sample - var minRank = 0 + var minRank = 0L var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) @@ -235,7 +235,7 @@ object QuantileSummaries { * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. */ - case class Stats(value: Double, g: Int, delta: Int) + case class Stats(value: Double, g: Long, delta: Long) private def compressImmut( currentSamples: IndexedSeq[Stats], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84fe4bb711a4..5c61f10bb71a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -70,7 +72,7 @@ object SQLConf { * Default config. Only used when there is no active SparkSession for the thread. * See [[get]] for more information. */ - private val fallbackConf = new ThreadLocal[SQLConf] { + private lazy val fallbackConf = new ThreadLocal[SQLConf] { override def initialValue: SQLConf = new SQLConf } @@ -323,11 +325,14 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. 
Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -336,8 +341,8 @@ object SQLConf { .createWithDefault(true) val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") - .doc("Whether to follow Parquet's format specification when converting Parquet schema to " + - "Spark SQL schema and vice versa.") + .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") .booleanConf .createWithDefault(false) @@ -364,8 +369,10 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " + - "none, uncompressed, snappy, zlib, lzo.") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) @@ -1044,6 +1051,18 @@ object SQLConf { "When this conf is not set, the value from `spark.redaction.string.regex` is used.") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) + val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString") + .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " + + "an output as binary. Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1060,6 +1079,24 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. 
Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1081,6 +1118,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) @@ -1378,6 +1421,13 @@ class SQLConf extends Serializable with Logging { def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 5dcd653e9b34..52a7ebdafd7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -869,6 +869,114 @@ class TypeCoercionSuite extends AnalysisTest { Literal.create(null, IntegerType), Literal.create(null, StringType)))) } + test("type coercion for Concat") { + val rule = TypeCoercion.ConcatCoercion(conf) + + ruleTest(rule, + Concat(Seq(Literal("ab"), Literal("cde"))), + Concat(Seq(Literal("ab"), Literal("cde")))) + ruleTest(rule, + Concat(Seq(Literal(null), Literal("abc"))), + Concat(Seq(Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Concat(Seq(Literal(1), Literal("234"))), + Concat(Seq(Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Concat(Seq(Literal("1"), Literal("234".getBytes()))), + Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))), + Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))), + Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(0.1))), + Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(Decimal(10)))), + Concat(Seq(Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(BigDecimal.valueOf(10)))), + Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))), + 
Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.concatBinaryAsString" -> "true") { + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + Concat(Seq(Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.concatBinaryAsString" -> "false") { + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) + } + } + + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion(conf) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. 
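The coercion tests above mirror what the two new flags do end to end. A rough usage sketch, assuming a local build that includes this change; the session setup, column names and byte values below are illustrative only:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, concat, expr, lit}

val spark = SparkSession.builder().master("local[1]").appName("binary-concat-elt").getOrCreate()

// Two binary columns built from literal byte arrays.
val bins = spark.range(1).select(lit("ab".getBytes).as("c1"), lit("cd".getBytes).as("c2"))

// Default (false): with all-binary inputs, concat keeps a binary result type.
spark.conf.set("spark.sql.function.concatBinaryAsString", "false")
bins.select(concat(col("c1"), col("c2")).as("c")).printSchema()   // c: binary

// Legacy behaviour: binary inputs are cast to string first, so the result is a string.
spark.conf.set("spark.sql.function.concatBinaryAsString", "true")
bins.select(concat(col("c1"), col("c2")).as("c")).printSchema()   // c: string

// elt is controlled independently by spark.sql.function.eltOutputAsString.
spark.conf.set("spark.sql.function.eltOutputAsString", "false")
bins.select(expr("elt(1, c1, c2)").as("e")).printSchema()         // e: binary

spark.stop()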
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1dd040e4696a..1445bb8a97d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -853,4 +853,57 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast("2", LongType).genCode(ctx) assert(ctx.inlinedMutableStates.length == 0) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, "[ab,, c]") + val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } + + test("SPARK-22973 Cast map to string") { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 54cde77176e2..97ddbeba2c5c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) } + test("SPARK-22771 Check Concat.checkInputDataTypes results") { + assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess) + 
assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil) + .checkInputDataTypes().isFailure) + assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isFailure) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala index 10479630f3f9..30e3bc9fb577 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch /** - * Unit test suite for the count-min sketch SQL aggregate funciton [[CountMinSketchAgg]]. + * Unit test suite for the count-min sketch SQL aggregate function [[CountMinSketchAgg]]. */ class CountMinSketchAggSuite extends SparkFunSuite { private val childExpression = BoundReference(0, IntegerType, nullable = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 77e4eff26c69..3f41f4b14409 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -38,54 +38,64 @@ class ColumnPruningSuite extends PlanTest { CollapseProject) :: Nil } - test("Column pruning for Generate when Generate.join = false") { - val input = LocalRelation('a.int, 'b.array(StringType)) + test("Column pruning for Generate when Generate.unrequiredChildIndex = child.output") { + val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) - val query = input.generate(Explode('b), join = false).analyze + val query = + input + .generate(Explode('c), outputNames = "explode" :: Nil) + .select('c, 'explode) + .analyze val optimized = Optimize.execute(query) - val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze + val correctAnswer = + input + .select('c) + .generate(Explode('c), outputNames = "explode" :: Nil) + .analyze comparePlans(optimized, correctAnswer) } - test("Column pruning for Generate when Generate.join = true") { - val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) + test("Fill Generate.unrequiredChildIndex if possible") { + val input = LocalRelation('b.array(StringType)) val query = input - .generate(Explode('c), join = true, outputNames = "explode" :: Nil) - .select('a, 'explode) + .generate(Explode('b), outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) .analyze val optimized = Optimize.execute(query) val correctAnswer = input - .select('a, 'c) - .generate(Explode('c), join = true, outputNames = "explode" :: Nil) - .select('a, 'explode) + .generate(Explode('b), unrequiredChildIndex = input.output.zipWithIndex.map(_._2), + outputNames = "explode" :: Nil) + .select(('explode + 
1).as("result")) .analyze comparePlans(optimized, correctAnswer) } - test("Turn Generate.join to false if possible") { - val input = LocalRelation('b.array(StringType)) + test("Another fill Generate.unrequiredChildIndex if possible") { + val input = LocalRelation('a.int, 'b.int, 'c1.string, 'c2.string) val query = input - .generate(Explode('b), join = true, outputNames = "explode" :: Nil) - .select(('explode + 1).as("result")) + .generate(Explode(CreateArray(Seq('c1, 'c2))), outputNames = "explode" :: Nil) + .select('a, 'c1, 'explode) .analyze val optimized = Optimize.execute(query) val correctAnswer = input - .generate(Explode('b), join = false, outputNames = "explode" :: Nil) - .select(('explode + 1).as("result")) + .select('a, 'c1, 'c2) + .generate(Explode(CreateArray(Seq('c1, 'c2))), + unrequiredChildIndex = Seq(2), + outputNames = "explode" :: Nil) .analyze comparePlans(optimized, correctAnswer) @@ -246,7 +256,7 @@ class ColumnPruningSuite extends PlanTest { x.select('a) .sortBy(SortOrder('a, Ascending)).analyze - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + comparePlans(optimized, correctAnswer) // push down invalid val originalQuery1 = { @@ -261,7 +271,7 @@ class ColumnPruningSuite extends PlanTest { .sortBy(SortOrder('a, Ascending)) .select('b).analyze - comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) + comparePlans(optimized1, correctAnswer1) } test("Column pruning on Window with useless aggregate functions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala index 412e199dfaae..441c15340a77 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.StringType class CombineConcatsSuite extends PlanTest { @@ -37,8 +36,10 @@ class CombineConcatsSuite extends PlanTest { comparePlans(actual, correctAnswer) } + def str(s: String): Literal = Literal(s) + def binary(s: String): Literal = Literal(s.getBytes) + test("combine nested Concat exprs") { - def str(s: String): Literal = Literal(s, StringType) assertEquivalent( Concat( Concat(str("a") :: str("b") :: Nil) :: @@ -72,4 +73,13 @@ class CombineConcatsSuite extends PlanTest { Nil), Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil)) } + + test("combine string and binary exprs") { + assertEquivalent( + Concat( + Concat(str("a") :: str("b") :: Nil) :: + Concat(binary("c") :: binary("d") :: Nil) :: + Nil), + Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 641824e6698f..85a5e979f602 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -504,7 +504,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = 
Optimize.execute(originalQuery.analyze) - comparePlans(analysis.EliminateSubqueryAliases(originalQuery.analyze), optimized) + comparePlans(originalQuery.analyze, optimized) } test("joins: conjunctive predicates") { @@ -523,7 +523,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + comparePlans(optimized, correctAnswer) } test("joins: conjunctive predicates #2") { @@ -542,7 +542,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + comparePlans(optimized, correctAnswer) } test("joins: conjunctive predicates #3") { @@ -566,7 +566,7 @@ class FilterPushdownSuite extends PlanTest { condition = Some("z.a".attr === "x.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + comparePlans(optimized, correctAnswer) } test("joins: push down where clause into left anti join") { @@ -581,7 +581,7 @@ class FilterPushdownSuite extends PlanTest { x.where("x.a".attr > 10) .join(y, LeftAnti, Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + comparePlans(optimized, correctAnswer) } test("joins: only push down join conditions to the right of a left anti join") { @@ -598,7 +598,7 @@ class FilterPushdownSuite extends PlanTest { LeftAnti, Some("x.b".attr === "y.b".attr && "x.a".attr > 10)) .analyze - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + comparePlans(optimized, correctAnswer) } test("joins: only push down join conditions to the right of an existence join") { @@ -616,7 +616,7 @@ class FilterPushdownSuite extends PlanTest { ExistenceJoin(fillerVal), Some("x.a".attr > 1)) .analyze - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + comparePlans(optimized, correctAnswer) } val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) @@ -624,14 +624,14 @@ class FilterPushdownSuite extends PlanTest { test("generate: predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), alias = Some("arr")) .where(('b >= 5) && ('a > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where(('b >= 5) && ('a > 6)) - .generate(Explode('c_arr), true, false, Some("arr")).analyze + .generate(Explode('c_arr), alias = Some("arr")).analyze } comparePlans(optimized, correctAnswer) @@ -640,14 +640,14 @@ class FilterPushdownSuite extends PlanTest { test("generate: non-deterministic predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), alias = Some("arr")) .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where('b >= 5) - .generate(Explode('c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), alias = Some("arr")) .where('a + Rand(10).as("rnd") > 6 && 'col > 6) .analyze } @@ -659,14 +659,14 @@ class FilterPushdownSuite extends PlanTest { val generator = Explode('c_arr) val originalQuery = { testRelationWithArrayType - 
.generate(generator, true, false, Some("arr")) + .generate(generator, alias = Some("arr")) .where(('b >= 5) && ('c > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val referenceResult = { testRelationWithArrayType .where('b >= 5) - .generate(generator, true, false, Some("arr")) + .generate(generator, alias = Some("arr")) .where('c > 6).analyze } @@ -687,7 +687,7 @@ class FilterPushdownSuite extends PlanTest { test("generate: all conjuncts referenced generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), alias = Some("arr")) .where(('col > 6) || ('b > 5)).analyze } val optimized = Optimize.execute(originalQuery) @@ -831,9 +831,9 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Union(Seq( - testRelation.where('a === 2L), - testRelation2.where('d === 2L))) - .where('b + Rand(10).as("rnd") === 3 && 'c > 5L) + testRelation.where('a === 2L && 'c > 5L), + testRelation2.where('d === 2L && 'f > 5L))) + .where('b + Rand(10).as("rnd") === 3) .analyze comparePlans(optimized, correctAnswer) @@ -1134,12 +1134,13 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - // Verify that all conditions preceding the first non-deterministic condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = x.join(y, condition = Some("x.a".attr === 5 && "y.a".attr === 5 && "x.a".attr === Rand(10) && "y.b".attr === 5)) - val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5), - condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5)) + val correctAnswer = + x.where("x.a".attr === 5).join(y.where("y.a".attr === 5 && "y.b".attr === 5), + condition = Some("x.a".attr === Rand(10))) // CheckAnalysis will ensure nondeterministic expressions not appear in join condition. // TODO support nondeterministic expressions in join condition. @@ -1147,16 +1148,16 @@ class FilterPushdownSuite extends PlanTest { checkAnalysis = false) } - test("watermark pushdown: no pushdown on watermark attribute") { + test("watermark pushdown: no pushdown on watermark attribute #1") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = EventTimeWatermark('b, interval, testRelation) .where('a === 5 && 'b === 10 && 'c === 5) val correctAnswer = EventTimeWatermark( - 'b, interval, testRelation.where('a === 5)) - .where('b === 10 && 'c === 5) + 'b, interval, testRelation.where('a === 5 && 'c === 5)) + .where('b === 10) comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, checkAnalysis = false) @@ -1165,7 +1166,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: no pushdown for nondeterministic filter") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. 
val originalQuery = EventTimeWatermark('c, interval, testRelation) .where('a === 5 && 'b === Rand(10) && 'c === 5) @@ -1180,7 +1181,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: full pushdown") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = EventTimeWatermark('c, interval, testRelation) .where('a === 5 && 'b === 10) @@ -1191,15 +1192,15 @@ class FilterPushdownSuite extends PlanTest { checkAnalysis = false) } - test("watermark pushdown: empty pushdown") { + test("watermark pushdown: no pushdown on watermark attribute #2") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down - // by the optimizer and others are not. val originalQuery = EventTimeWatermark('a, interval, testRelation) .where('a === 5 && 'b === 10) + val correctAnswer = EventTimeWatermark( + 'a, interval, testRelation.where('b === 10)).where('a === 5) - comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze, + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, checkAnalysis = false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index dccb32f0379a..c28844642aed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) - val a1 = c1.toAttribute.withNullability(true) - val a2 = c2.toAttribute.withNullability(true) + val a1 = c1.toAttribute.newInstance().withNullability(true) + val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), Seq(a1, a2), @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest { val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } + + test("Propagate above outer join") { + val left = LocalRelation('a.int).select('a, Literal(1).as('b)) + val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + + val join = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && 'b === 'd)) + val query = join.select(('b + 3).as('res)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as('res)).analyze + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 5580f8604ec7..a0708bf7eee9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -236,4 +236,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("constraints should be inferred from aliased literals") { + val originalLeft = testRelation.subquery('left).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + + val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val condition = Some("left.a".attr === "right.two".attr) + + val original = originalLeft.join(right, Inner, condition) + val correct = optimizedLeft.join(right, Inner, condition) + + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 97733a754ccc..ccd9d8dd4d21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -118,7 +117,7 @@ class JoinOptimizationSuite extends PlanTest { queryAnswers foreach { queryAnswerPair => val optimized = Optimize.execute(queryAnswerPair._1.analyze) - comparePlans(optimized, analysis.EliminateSubqueryAliases(queryAnswerPair._2.analyze)) + comparePlans(optimized, queryAnswerPair._2.analyze) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index d34a83c42c67..812bfdd7bb88 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -276,7 +276,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual( "select * from t lateral view explode(x) expl as x", table("t") - .generate(explode, join = true, outer = false, Some("expl"), Seq("x")) + .generate(explode, alias = Some("expl"), outputNames = Seq("x")) .select(star())) // Multiple lateral views @@ -286,12 +286,12 @@ class PlanParserSuite extends AnalysisTest { |lateral view explode(x) expl |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, table("t") - .generate(explode, join = true, outer = false, Some("expl"), Seq.empty) - .generate(jsonTuple, join = true, outer = true, Some("jtup"), Seq("q", "z")) + .generate(explode, alias = Some("expl")) + .generate(jsonTuple, outer = true, alias = Some("jtup"), outputNames = Seq("q", "z")) .select(star())) // Multi-Insert lateral views. 
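// ---------------------------------------------------------------------------
// [Editor's aside, not part of the patch] A hedged sketch of the LATERAL VIEW
// shape covered by the PlanParserSuite and ColumnPruningSuite changes above:
// column pruning of Generate is now expressed through unrequiredChildIndex
// rather than the old join = true/false flag. Names below are illustrative and
// the snippet assumes a local Spark build.
// ---------------------------------------------------------------------------
import org.apache.spark.sql.SparkSession

object LateralViewPruningSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("generate-sketch").getOrCreate()
    import spark.implicits._

    Seq((1, 2, Seq("x", "y"))).toDF("a", "b", "c").createOrReplaceTempView("t")
    // Only 'a', 'c' and the generated column are needed, so the optimizer can
    // record the unneeded child outputs in Generate.unrequiredChildIndex.
    spark.sql("SELECT a, expl_x FROM t LATERAL VIEW explode(c) expl AS expl_x").explain(true)

    spark.stop()
  }
}
// ---------------------------------------------------------------------------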
- val from = table("t1").generate(explode, join = true, outer = false, Some("expl"), Seq("x")) + val from = table("t1").generate(explode, alias = Some("expl"), outputNames = Seq("x")) assertEqual( """from t1 |lateral view explode(x) expl as x @@ -303,7 +303,7 @@ class PlanParserSuite extends AnalysisTest { |where s < 10 """.stripMargin, Union(from - .generate(jsonTuple, join = true, outer = false, Some("jtup"), Seq("q", "z")) + .generate(jsonTuple, alias = Some("jtup"), outputNames = Seq("q", "z")) .select(star()) .insertInto("t2"), from.where('s < 10).select(star()).insertInto("t3"))) @@ -312,10 +312,8 @@ class PlanParserSuite extends AnalysisTest { val expected = table("t") .generate( UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)), - join = true, - outer = false, - Some("posexpl"), - Seq("x", "y")) + alias = Some("posexpl"), + outputNames = Seq("x", "y")) .select(star()) assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 3ba180860c32..c120863152a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -31,7 +31,6 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; -import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -96,7 +95,7 @@ public class VectorizedColumnReader { private final OriginalType originalType; // The timezone conversion to apply to int96 timestamps. Null if no conversion. private final TimeZone convertTz; - private final static TimeZone UTC = DateTimeUtils.TimeZoneUTC(); + private static final TimeZone UTC = DateTimeUtils.TimeZoneUTC(); public VectorizedColumnReader( ColumnDescriptor descriptor, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 14f2a58d638f..cd745b1f0e4e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -79,8 +79,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private boolean[] missingColumns; /** - * The timezone that timestamp INT96 values should be converted to. Null if no conversion. 
Here to workaround - * incompatibilities between different engines when writing timestamp values. + * The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to + * workaround incompatibilities between different engines when writing timestamp values. */ private TimeZone convertTz = null; @@ -248,7 +248,10 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. */ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc43484e..b5cbe8e2839b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c147dfe..70057a9def6c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125976e1..d2ae32b06f83 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -585,11 +586,11 @@ public final int appendArray(int length) { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 0b5b6ac675f2..3cb020d2e083 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -19,9 +19,6 @@ import org.apache.spark.annotation.InterfaceStability; -import java.util.List; -import java.util.Map; - /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java similarity index 81% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index ae4f85820649..3136cee1f655 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -15,12 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; -import org.apache.spark.sql.sources.v2.reader.ContinuousReader; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; import org.apache.spark.sql.types.StructType; /** @@ -38,5 +39,8 @@ public interface ContinuousReadSupport extends DataSourceV2 { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - ContinuousReader createContinuousReader(Optional schema, String checkpointLocation, DataSourceV2Options options); + ContinuousReader createContinuousReader( + Optional schema, + String checkpointLocation, + DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java similarity index 82% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java index 362d5f52b4d0..dee493cadb71 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java @@ -15,13 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.ContinuousWriter; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; @@ -37,9 +39,9 @@ public interface ContinuousWriteSupport extends BaseStreamingSink { * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done. * - * @param queryId A unique string for the writing query. It's possible that there are many writing - * queries running at the same time, and the returned {@link DataSourceV2Writer} - * can use this id to distinguish itself from others. + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link DataSourceV2Writer} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the output mode which determines what successive epoch output means to this * sink, please refer to {@link OutputMode} for more details. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 442cad029d21..3c87a3db6824 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.MicroBatchReader; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java similarity index 91% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java index 63640779b955..53ffa95ae0f4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java @@ -15,12 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; @@ -39,7 +41,7 @@ public interface MicroBatchWriteSupport extends BaseStreamingSink { * @param queryId A unique string for the writing query. It's possible that there are many writing * queries running at the same time, and the returned {@link DataSourceV2Writer} * can use this id to distinguish itself from others. - * @param epochId The uniquenumeric ID of the batch within this writing query. This is an + * @param epochId The unique numeric ID of the batch within this writing query. This is an * incrementing counter representing a consistent set of data; the same batch may * be started multiple times in failure recovery scenarios, but it will always * contain the same records. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java index 11b99a93f149..ca9a290e97a0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java @@ -15,11 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.PartitionOffset; - -import java.io.IOException; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A variation on {@link DataReader} for use with streaming in continuous processing mode. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 34141d6cd85f..f0b205869ed6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -15,10 +15,10 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.PartitionOffset; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import java.util.Optional; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index bd15c07d87f6..70ff75680603 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.Offset; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import java.util.Optional; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index ce1c48974205..60b87f2ac075 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; /** * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]]. @@ -42,7 +42,8 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of @Override public boolean equals(Object obj) { if (obj instanceof org.apache.spark.sql.execution.streaming.Offset) { - return this.json().equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); + return this.json() + .equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); } else { return false; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index 07826b668847..eca0085c8a8c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; import java.io.Serializable; @@ -26,5 +26,4 @@ * These offsets must be serializable. 
*/ public interface PartitionOffset extends Serializable { - } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java similarity index 87% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java index 618f47ed79ca..723395bd1e96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java @@ -15,9 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.writer; +package org.apache.spark.sql.sources.v2.streaming.writer; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataWriter; +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** * A {@link DataSourceV2Writer} for use with continuous stream processing. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 528f66f342dc..708333213f3f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; @@ -34,11 +34,7 @@ public final class ArrowColumnVector extends ColumnVector { private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } + ensureAccessible(index, 1); } private void ensureAccessible(int index, int count) { @@ -64,20 +60,12 @@ public void close() { accessor.close(); } - // - // APIs dealing with nulls - // - @Override public boolean isNullAt(int rowId) { ensureAccessible(rowId); return accessor.isNullAt(rowId); } - // - // APIs dealing with Booleans - // - @Override public boolean getBoolean(int rowId) { ensureAccessible(rowId); @@ -94,10 +82,6 @@ public boolean[] getBooleans(int rowId, int count) { return array; } - // - // APIs dealing with Bytes - // - @Override public byte getByte(int rowId) { ensureAccessible(rowId); @@ -114,10 +98,6 @@ public byte[] getBytes(int rowId, int count) { return array; } - // - // APIs dealing with Shorts - // - @Override public short getShort(int rowId) { ensureAccessible(rowId); @@ -134,10 +114,6 @@ public short[] getShorts(int rowId, int count) { return array; } - // - // APIs dealing with Ints - // - @Override public int getInt(int rowId) { ensureAccessible(rowId); @@ -154,10 +130,6 @@ public int[] getInts(int rowId, int count) { return array; } - // - // APIs dealing with Longs - // - @Override public long getLong(int rowId) { ensureAccessible(rowId); @@ -174,10 +146,6 @@ public long[] getLongs(int rowId, int count) { return array; } - // - // APIs dealing with floats - // - @Override public float getFloat(int rowId) { ensureAccessible(rowId); @@ -194,10 +162,6 @@ public float[] getFloats(int rowId, int count) { return array; } - // - // APIs dealing with doubles - // - @Override public double getDouble(int rowId) { ensureAccessible(rowId); @@ -214,10 +178,6 @@ public double[] getDoubles(int rowId, int count) { return array; } - // - // APIs dealing with Arrays - // - @Override public int getArrayLength(int rowId) { ensureAccessible(rowId); @@ -230,45 +190,27 @@ public int getArrayOffset(int rowId) { return accessor.getArrayOffset(rowId); } - // - // APIs dealing with Decimals - // - @Override public Decimal getDecimal(int rowId, int precision, int scale) { ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } - // - // APIs dealing with UTF8Strings - // - @Override public UTF8String getUTF8String(int rowId) { ensureAccessible(rowId); return accessor.getUTF8String(rowId); } - // - // APIs dealing with Binaries - // - @Override public byte[] getBinary(int rowId) { ensureAccessible(rowId); return accessor.getBinary(rowId); } - /** - * Returns the data for the underlying array. - */ @Override public ArrowColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. 
- */ @Override public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } @@ -326,7 +268,8 @@ private abstract static class ArrowVectorAccessor { this.vector = vector; } - final boolean isNullAt(int rowId) { + // TODO: should be final after removing ArrayAccessor workaround + boolean isNullAt(int rowId) { return vector.isNull(rowId); } @@ -589,6 +532,16 @@ private static class ArrayAccessor extends ArrowVectorAccessor { this.accessor = vector; } + @Override + final boolean isNullAt(int rowId) { + // TODO: Workaround if vector has all non-null values, see ARROW-1948 + if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { + return false; + } else { + return super.isNullAt(rowId); + } + } + @Override final int getArrayLength(int rowId) { return accessor.getInnerValueCountAt(rowId); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index dc7c1269bedd..d1196e1299fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; @@ -22,24 +22,31 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. * - * Maps are just a special case of a two field struct. + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. For struct type, it stores the actual + * data of each field in the corresponding child ColumnVector, and only stores null information in + * the parent ColumnVector. For array type, it stores the actual array elements in the child + * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. 
+ * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when design the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. */ public abstract class ColumnVector implements AutoCloseable { /** - * Returns the data type of this column. + * Returns the data type of this column vector. */ public final DataType dataType() { return type; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index cbc39d1d0aec..0d89a52e7a4f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -23,8 +23,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Array abstraction in {@link ColumnVector}. */ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from @@ -33,7 +32,7 @@ public final class ColumnarArray extends ArrayData { private final int offset; private final int length; - ColumnarArray(ColumnVector data, int offset, int length) { + public ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; this.offset = offset; this.length = length; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index a9d09aa67972..9ae1c6d9993f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -14,26 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import java.util.*; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; import org.apache.spark.sql.types.StructType; /** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. 
Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. + * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. */ public final class ColumnarBatch { public static final int DEFAULT_BATCH_SIZE = 4 * 1024; @@ -57,7 +49,7 @@ public void close() { } /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + * Returns an iterator over the rows in this batch. */ public Iterator rowIterator() { final int maxRows = numRows; @@ -87,19 +79,7 @@ public void remove() { } /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. + * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 95c0d09873d6..3c6656dec77c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -24,16 +24,16 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Row abstraction in {@link ColumnVector}. */ public final class ColumnarRow extends InternalRow { - // The data for this row. E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + // The data for this row. + // E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. 
private final ColumnVector data; private final int rowId; private final int numFields; - ColumnarRow(ColumnVector data, int rowId) { + public ColumnarRow(ColumnVector data, int rowId) { assert (data.dataType() instanceof StructType); this.data = data; this.rowId = rowId; diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css index 594e747a8d3a..b13850c30149 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -32,7 +32,7 @@ stroke-width: 1px; } -/* Hightlight the SparkPlan node name */ +/* Highlight the SparkPlan node name */ #plan-viz-graph svg text :first-child { font-weight: bold; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c43ee91294a2..e8d683a578f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -517,17 +517,20 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * * You can set the following CSV-specific options to deal with CSV files: *
-   * <li>`sep` (default `,`): sets the single character as a separator for each
+   * <li>`sep` (default `,`): sets a single character as a separator for each
    * field and value.</li>
    * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
    * type.</li>
-   * <li>`quote` (default `"`): sets the single character used for escaping quoted values where
+   * <li>`quote` (default `"`): sets a single character used for escaping quoted values where
    * the separator can be part of the value. If you would like to turn off quotations, you need to
    * set not `null` but an empty string. This behaviour is different from
    * `com.databricks.spark.csv`.</li>
-   * <li>`escape` (default `\`): sets the single character used for escaping quotes inside
+   * <li>`escape` (default `\`): sets a single character used for escaping quotes inside
    * an already quoted value.</li>
-   * <li>`comment` (default empty string): sets the single character used for skipping lines
+   * <li>`charToEscapeQuoteEscaping` (default `escape` or `\0`): sets a single character used for
+   * escaping the escape for the quote character. The default value is escape character when escape
+   * and quote characters are different, `\0` otherwise.</li>
+   * <li>`comment` (default empty string): sets a single character used for skipping lines
    * beginning with this character. By default, it is disabled.</li>
    * <li>`header` (default `false`): uses the first line as names of columns.</li>
    * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7ccda0ad36d1..3304f368e105 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} @@ -264,7 +264,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } } @@ -594,13 +594,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * * You can set the following CSV-specific option(s) for writing CSV files: *
        - *
      • `sep` (default `,`): sets the single character as a separator for each + *
      • `sep` (default `,`): sets a single character as a separator for each * field and value.
      • - *
      • `quote` (default `"`): sets the single character used for escaping quoted values where + *
      • `quote` (default `"`): sets a single character used for escaping quoted values where * the separator can be part of the value. If an empty string is set, it uses `u0000` * (null character).
      • - *
      • `escape` (default `\`): sets the single character used for escaping quotes inside + *
      • `escape` (default `\`): sets a single character used for escaping quotes inside * an already quoted value.
      • + *
      • `charToEscapeQuoteEscaping` (default `escape` or `\0`): sets a single character used for + * escaping the escape for the quote character. The default value is escape character when escape + * and quote characters are different, `\0` otherwise.
      • *
      • `escapeQuotes` (default `true`): a flag indicating whether values containing * quotes should always be enclosed in quotes. Default is to escape all values containing * a quote character.
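To make the documented option above concrete, here is a minimal usage sketch of `charToEscapeQuoteEscaping` on both the read and the write path. It assumes an active SparkSession named `spark` and made-up file paths; it is an illustration of the documented option, not part of the patch.

  // Quote and escape are both '"' here, so the escape of the quote escape
  // would otherwise default to '\0'; set it explicitly to backslash.
  val df = spark.read
    .option("quote", "\"")
    .option("escape", "\"")
    .option("charToEscapeQuoteEscaping", "\\")
    .csv("/tmp/input.csv")                        // hypothetical input path

  df.write
    .option("charToEscapeQuoteEscaping", "\\")
    .csv("/tmp/output")                           // hypothetical output path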
      • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 209b800fdc6f..77e571272920 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2095,7 +2095,7 @@ class Dataset[T] private[sql]( val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) withPlan { - Generate(generator, join = true, outer = false, + Generate(generator, unrequiredChildIndex = Nil, outer = false, qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -2136,7 +2136,7 @@ class Dataset[T] private[sql]( val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil) withPlan { - Generate(generator, join = true, outer = false, + Generate(generator, unrequiredChildIndex = Nil, outer = false, qualifier = None, generatorOutput = Nil, planWithBarrier) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3ff476147b8b..f94baef39dfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql -import java.lang.reflect.{ParameterizedType, Type} +import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} @@ -41,8 +42,6 @@ import org.apache.spark.util.Utils * spark.udf * }}} * - * @note The user-defined functions must be deterministic. - * * @since 1.3.0 */ @InterfaceStability.Stable @@ -58,6 +57,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | pythonIncludes: ${udf.func.pythonIncludes} | pythonExec: ${udf.func.pythonExec} | dataType: ${udf.dataType} + | pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)} + | udfDeterministic: ${udf.udfDeterministic} """.stripMargin) functionRegistry.createOrReplaceTempFunction(name, udf.builder) @@ -109,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 22).map { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" - /** - * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF. 
- * @since 1.3.0 - */ - def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) - } else { - throw new AnalysisException("Invalid number of arguments for function " + name + - ". Expected: $x; Found: " + e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | def builder(e: Seq[Expression]) = if (e.length == $x) { + | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | } else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $x; Found: " + e.length) + | } + | functionRegistry.createOrReplaceTempFunction(name, builder) + | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) } (0 to 22).foreach { i => @@ -143,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val funcCall = if (i == 0) "() => func" else "func" println(s""" |/** - | * Register a user-defined function with ${i} arguments. + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). | * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { @@ -688,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 0 arguments. + * Register a deterministic Java UDF0 instance as user-defined function (UDF). * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { @@ -703,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 1 arguments. + * Register a deterministic Java UDF1 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { @@ -718,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 2 arguments. + * Register a deterministic Java UDF2 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { @@ -733,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 3 arguments. + * Register a deterministic Java UDF3 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { @@ -748,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 4 arguments. + * Register a deterministic Java UDF4 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { @@ -763,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 5 arguments. + * Register a deterministic Java UDF5 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { @@ -778,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 6 arguments. + * Register a deterministic Java UDF6 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -793,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 7 arguments. + * Register a deterministic Java UDF7 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -808,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 8 arguments. + * Register a deterministic Java UDF8 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -823,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 9 arguments. + * Register a deterministic Java UDF9 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -838,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 10 arguments. + * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -853,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 11 arguments. + * Register a deterministic Java UDF11 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -868,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 12 arguments. + * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -883,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 13 arguments. + * Register a deterministic Java UDF13 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -898,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 14 arguments. + * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -913,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 15 arguments. + * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -928,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 16 arguments. + * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -943,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 17 arguments. + * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -958,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 18 arguments. + * Register a deterministic Java UDF18 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -973,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 19 arguments. + * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -988,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 20 arguments. + * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1003,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 21 arguments. + * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1018,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 22 arguments. + * Register a deterministic Java UDF22 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292b..5617046e1396 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** * Helper trait for abstracting scan functionality using - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es. + * [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index e1562befe14f..0c2c4a1a9100 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -47,8 +47,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * terminate(). * * @param generator the generator expression - * @param join when true, each output row is implicitly joined with the input tuple that produced - * it. + * @param requiredChildOutput required attributes from child's output * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. 
* @param generatorOutput the qualified output attributes of the generator of this node, which @@ -57,19 +56,13 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In */ case class GenerateExec( generator: Generator, - join: Boolean, + requiredChildOutput: Seq[Attribute], outer: Boolean, generatorOutput: Seq[Attribute], child: SparkPlan) extends UnaryExecNode with CodegenSupport { - override def output: Seq[Attribute] = { - if (join) { - child.output ++ generatorOutput - } else { - generatorOutput - } - } + override def output: Seq[Attribute] = requiredChildOutput ++ generatorOutput override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -85,11 +78,19 @@ case class GenerateExec( val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithIndexInternal { (index, iter) => val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) - val rows = if (join) { + val rows = if (requiredChildOutput.nonEmpty) { + + val pruneChildForResult: InternalRow => InternalRow = + if (child.outputSet == AttributeSet(requiredChildOutput)) { + identity + } else { + UnsafeProjection.create(requiredChildOutput, child.output) + } + val joinedRow = new JoinedRow iter.flatMap { row => - // we should always set the left (child output) - joinedRow.withLeft(row) + // we should always set the left (required child output) + joinedRow.withLeft(pruneChildForResult(row)) val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { joinedRow.withRight(generatorNullRow) :: Nil @@ -136,7 +137,7 @@ case class GenerateExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { // Add input rows to the values when we are joining - val values = if (join) { + val values = if (requiredChildOutput.nonEmpty) { input } else { Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index daff3c49e751..ef1bb1c2a446 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -138,7 +138,7 @@ case class SortExec( // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. val thisPlan = ctx.addReferenceObj("plan", this) - // inline mutable state since not many Sort operations in a task + // Inline mutable state since not many Sort operations in a task sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter", v => s"$v = $thisPlan.createSorter();", forceInline = true) val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 29b584b55972..d3cfd2a1ffbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) 
- * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { @@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText val schema = Option(ctx.colTypeList()).map(createSchema) @@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { .map(visitIdentifierList(_).toArray) .getOrElse(Array.empty[String]) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) if (location.isDefined && storage.locationUri.isDefined) { @@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name * [(col1[:] data_type [COMMENT col_comment], ...)] - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [AS select_statement]; + * + * create_table_clauses (order insensitive): + * [COMMENT table_comment] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { @@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "CREATE TEMPORARY TABLE is not supported yet. " + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } - if (ctx.skewSpec != null) { + if (ctx.skewSpec.size > 0) { operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly @@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) .getOrElse(CatalogStorageFormat.empty) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) // If we are creating an EXTERNAL table, then the LOCATION field is required if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) @@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx) } - val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. @@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + private def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + /** * Create or replace a view. This creates a [[CreateViewCommand]] command. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8c6c324d456c..910294853c31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -158,45 +158,65 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def smallerSide = if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - val buildRight = canBuildRight && right.stats.hints.broadcast - val buildLeft = canBuildLeft && left.stats.hints.broadcast - - if (buildRight && buildLeft) { + if (canBuildRight && canBuildLeft) { // Broadcast smaller side base on its estimated physical size // if both sides have broadcast hint smallerSide - } else if (buildRight) { + } else if (canBuildRight) { BuildRight - } else if (buildLeft) { + } else if (canBuildLeft) { BuildLeft - } else if (canBuildRight && canBuildLeft) { + } else { // for the last default broadcast nested loop join smallerSide - } else { - throw new AnalysisException("Can not decide which side to broadcast for this join") } } + private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : Boolean = { + val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast + val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast + buildLeft || buildRight + } + + private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : BuildSide = { + val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast + val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast + broadcastSide(buildLeft, buildRight, left, right) + } + + private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : Boolean = { + val buildLeft = canBuildLeft(joinType) && canBroadcast(left) + val buildRight = canBuildRight(joinType) && canBroadcast(right) + buildLeft || buildRight + } + + private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : BuildSide = { + val buildLeft = canBuildLeft(joinType) && canBroadcast(left) + val buildRight = canBuildRight(joinType) && canBroadcast(right) + broadcastSide(buildLeft, buildRight, left, right) + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // --- BroadcastHashJoin -------------------------------------------------------------------- + // broadcast hints were specified case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if (canBuildRight(joinType) && right.stats.hints.broadcast) || - (canBuildLeft(joinType) && left.stats.hints.broadcast) => - val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + if canBroadcastByHints(joinType, left, right) => + val buildSide = broadcastSideByHints(joinType, left, right) Seq(joins.BroadcastHashJoinExec( leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + // broadcast hints were not specified, so need to infer it from size and configuration. 
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if canBuildRight(joinType) && canBroadcast(right) => - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if canBuildLeft(joinType) && canBroadcast(left) => + if canBroadcastBySizes(joinType, left, right) => + val buildSide = broadcastSideBySizes(joinType, left, right) Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) // --- ShuffledHashJoin --------------------------------------------------------------------- @@ -225,27 +245,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Pick BroadcastNestedLoopJoin if one side could be broadcasted case j @ logical.Join(left, right, joinType, condition) - if (canBuildRight(joinType) && right.stats.hints.broadcast) || - (canBuildLeft(joinType) && left.stats.hints.broadcast) => - val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + if canBroadcastByHints(joinType, left, right) => + val buildSide = broadcastSideByHints(joinType, left, right) joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case j @ logical.Join(left, right, joinType, condition) - if canBuildRight(joinType) && canBroadcast(right) => + if canBroadcastBySizes(joinType, left, right) => + val buildSide = broadcastSideBySizes(joinType, left, right) joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil - case j @ logical.Join(left, right, joinType, condition) - if canBuildLeft(joinType) && canBroadcast(left) => - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil // Pick CartesianProduct for InnerJoin case logical.Join(left, right, _: InnerLike, condition) => joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil case logical.Join(left, right, joinType, condition) => - val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right) + val buildSide = broadcastSide( + left.stats.hints.broadcast, right.stats.hints.broadcast, left, right) // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil @@ -403,6 +420,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? 
object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil case r: RunnableCommand => ExecutedCommandExec(r) :: Nil case MemoryPlan(sink, output) => @@ -481,10 +499,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.GlobalLimitExec(limit, planLater(child)) :: Nil case logical.Union(unionChildren) => execution.UnionExec(unionChildren.map(planLater)) :: Nil - case g @ logical.Generate(generator, join, outer, _, _, child) => + case g @ logical.Generate(generator, _, outer, _, _, child) => execution.GenerateExec( - generator, join = join, outer = outer, g.qualifiedGeneratorOutput, - planLater(child)) :: Nil + generator, g.requiredChildOutput, outer, + g.qualifiedGeneratorOutput, planLater(child)) :: Nil case _: logical.OneRowRelation => execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9e7008d1e0c3..065954559e48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -283,7 +283,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp override def doProduce(ctx: CodegenContext): String = { // Right now, InputAdapter is only used when there is one input RDD. - // inline mutable state since an inputAdaptor in a task + // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", forceInline = true) val row = ctx.freshName("row") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b1af360d8509..ce3c68810f3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -587,31 +587,35 @@ case class HashAggregateExec( fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) + // Inline mutable state since not many aggregation operations in a task fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedHastHashMap", - v => s"$v = new $fastHashMapClassName();") - ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter") + v => s"$v = new $fastHashMapClassName();", forceInline = true) + ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter", + forceInline = true) } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, 
fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) + // Inline mutable state since not many aggregation operations in a task fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap", v => s"$v = new $fastHashMapClassName(" + - s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());", + forceInline = true) ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", - "fastHashMapIter") + "fastHashMapIter", forceInline = true) } } // Create a name for the iterator from the regular hash map. - // inline mutable state since not many aggregation operations in a task + // Inline mutable state since not many aggregation operations in a task val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, "mapIter", forceInline = true) // create hashMap val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap", - v => s"$v = $thisPlan.createHashMap();") + v => s"$v = $thisPlan.createHashMap();", forceInline = true) sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter", forceInline = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0380ee8b09d6..0cf9b53ce1d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch /** * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcfc41243026..bcd1aa0890ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,8 +32,8 @@ import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 0258056d9de4..22b63513548f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -53,6 +53,8 @@ object ArrowWriter { case (LongType, vector: BigIntVector) => new LongWriter(vector) case (FloatType, vector: Float4Vector) => new FloatWriter(vector) case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) + case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => + new DecimalWriter(vector, precision, scale) case (StringType, vector: VarCharVector) => new StringWriter(vector) case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) case (DateType, vector: DateDayVector) => new DateWriter(vector) @@ -214,6 +216,25 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFi } } +private[arrow] class DecimalWriter( + val valueVector: DecimalVector, + precision: Int, + scale: Int) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val decimal = input.getDecimal(ordinal, precision, scale) + if (decimal.changePrecision(precision, scale)) { + valueVector.setSafe(count, decimal.toJavaBigDecimal) + } else { + setNull() + } + } +} + private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 78137d3f97cf..a15a8d11aa2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -284,7 +284,7 @@ case class SampleExec( val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") - // inline mutable state since not many Sample operations in a task + // Inline mutable state since not many Sample operations in a task val sampler = ctx.addMutableState(s"$samplerClass", "sampleReplace", v => { val initSamplerFuncName = ctx.addNewFunction(initSampler, @@ -371,7 +371,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName - // inline mutable state since not many Range operations in a task + // Inline mutable state since not many Range operations in a task val taskContext = ctx.addMutableState("TaskContext", "taskContext", v => s"$v = TaskContext.get();", forceInline = true) val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3e73393b1285..933b9753faa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} case class InMemoryTableScanExec( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index e3bb4d357b39..1122522ccb4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -143,7 +143,12 @@ case class AnalyzeColumnCommand( val percentilesRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExprs, relation)) .executedPlan.executeTake(1).head attrsToGenHistogram.zipWithIndex.foreach { case (attr, i) => - attributePercentiles += attr -> percentilesRow.getArray(i) + val percentiles = percentilesRow.getArray(i) + // When there is no non-null value, `percentiles` is null. In such case, there is no + // need to generate histogram. + if (percentiles != null) { + attributePercentiles += attr -> percentiles + } } } AttributeMap(attributePercentiles.toSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 2cf06982e25f..e56f8105fc9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -20,30 +20,32 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker +import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.SerializableConfiguration - /** - * A special `RunnableCommand` which writes data out and updates metrics. + * A special `Command` which writes data out and updates metrics. */ -trait DataWritingCommand extends RunnableCommand { - +trait DataWritingCommand extends Command { /** * The input query plan that produces the data to be written. + * IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns + * to [[FileFormatWriter]]. */ def query: LogicalPlan - // We make the input `query` an inner child instead of a child in order to hide it from the - // optimizer. This is because optimizer may not preserve the output schema names' case, and we - // have to keep the original analyzed plan here so that we can pass the corrected schema to the - // writer. The schema of analyzed plan is what user expects(or specifies), so we should respect - // it when writing. 
- override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override final def children: Seq[LogicalPlan] = query :: Nil - override lazy val metrics: Map[String, SQLMetric] = { + // Output columns of the analyzed input query plan + def outputColumns: Seq[Attribute] + + lazy val metrics: Map[String, SQLMetric] = { val sparkContext = SparkContext.getActive.get Map( "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), @@ -57,4 +59,6 @@ trait DataWritingCommand extends RunnableCommand { val serializableHadoopConf = new SerializableConfiguration(hadoopConf) new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) } + + def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index e28b5eb2e2a2..2cc0e38adc2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} -import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan} import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} @@ -87,6 +87,42 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode { } } +/** + * A physical operator that executes the run method of a `DataWritingCommand` and + * saves the result to prevent multiple executions. + * + * @param cmd the `DataWritingCommand` this operator will run. + * @param child the physical plan child ran by the `DataWritingCommand`. + */ +case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan) + extends SparkPlan { + + override lazy val metrics: Map[String, SQLMetric] = cmd.metrics + + protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val rows = cmd.run(sqlContext.sparkSession, child) + + rows.map(converter(_).asInstanceOf[InternalRow]) + } + + override def children: Seq[SparkPlan] = child :: Nil + + override def output: Seq[Attribute] = cmd.output + + override def nodeName: String = "Execute " + cmd.nodeName + + override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray + + override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator + + override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray + + protected override def doExecute(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(sideEffectResult, 1) + } +} + /** * An explain command for users to see how a command will be executed. 
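To illustrate the reworked contract above, here is a deliberately trivial, hypothetical `DataWritingCommand` implementation. It is not part of the patch; it only shows where the pre-planned physical `child` and the analyzed `outputColumns` come into play (a real command would forward both to `FileFormatWriter.write`).

  import org.apache.spark.sql.{Row, SparkSession}
  import org.apache.spark.sql.catalyst.expressions.Attribute
  import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
  import org.apache.spark.sql.execution.SparkPlan
  import org.apache.spark.sql.execution.command.DataWritingCommand

  // Hypothetical command that just materializes the query and writes nothing.
  case class NoopWritingCommand(
      query: LogicalPlan,
      outputColumns: Seq[Attribute]) extends DataWritingCommand {

    override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
      // `child` is the already-planned physical query handed in by
      // DataWritingCommandExec; count its rows and return no result rows.
      child.execute().count()
      Seq.empty[Row]
    }
  }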
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b676672b38cd..25e121050427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -456,17 +456,6 @@ case class DataSource( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) - - // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does - // not need to have the query as child, to avoid to analyze an optimized query, - // because InsertIntoHadoopFsRelationCommand will be optimized first. - val partitionAttributes = partitionColumns.map { name => - data.output.find(a => equality(a.name, name)).getOrElse { - throw new AnalysisException( - s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") - } - } - val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location @@ -479,14 +468,15 @@ case class DataSource( outputPath = outputPath, staticPartitions = Map.empty, ifPartitionNotExists = false, - partitionColumns = partitionAttributes, + partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted), bucketSpec = bucketSpec, fileFormat = format, options = options, query = data, mode = mode, catalogTable = catalogTable, - fileIndex = fileIndex) + fileIndex = fileIndex, + outputColumns = data.output) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 400f2e03165b..d94c5bbccdd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -208,7 +208,8 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast actualQuery, mode, table, - Some(t.location)) + Some(t.location), + actualQuery.output) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index d3874b58bc80..023e12788829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -77,7 +77,7 @@ trait FileFormat { } /** - * Returns whether a file with `path` could be splitted or not. + * Returns whether a file with `path` could be split or not. 
*/ def isSplitable( sparkSession: SparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1fac01a2c26c..1d80a69bc5a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} +import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -56,7 +56,9 @@ object FileFormatWriter extends Logging { /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( - outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String]) + outputPath: String, + customPartitionLocations: Map[TablePartitionSpec, String], + outputColumns: Seq[Attribute]) /** A shared job description for all the write tasks. */ private class WriteJobDescription( @@ -101,7 +103,7 @@ object FileFormatWriter extends Logging { */ def write( sparkSession: SparkSession, - queryExecution: QueryExecution, + plan: SparkPlan, fileFormat: FileFormat, committer: FileCommitProtocol, outputSpec: OutputSpec, @@ -117,11 +119,8 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) - // Pick the attributes from analyzed plan, as optimizer may not preserve the output schema - // names' case. - val allColumns = queryExecution.analyzed.output val partitionSet = AttributeSet(partitionColumns) - val dataColumns = allColumns.filterNot(partitionSet.contains) + val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains) val bucketIdExpression = bucketSpec.map { spec => val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) @@ -144,7 +143,7 @@ object FileFormatWriter extends Logging { uuid = UUID.randomUUID().toString, serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), outputWriterFactory = outputWriterFactory, - allColumns = allColumns, + allColumns = outputSpec.outputColumns, dataColumns = dataColumns, partitionColumns = partitionColumns, bucketIdExpression = bucketIdExpression, @@ -160,7 +159,7 @@ object FileFormatWriter extends Logging { // We should first sort by partition columns, then bucket id, and finally sorting columns. 
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns // the sort order doesn't matter - val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) + val actualOrdering = plan.outputOrdering.map(_.child) val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { false } else { @@ -178,17 +177,18 @@ object FileFormatWriter extends Logging { try { val rdd = if (orderingMatched) { - queryExecution.toRdd + plan.execute() } else { // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and // the physical plan may have different attribute ids due to optimizer removing some // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. val orderingExpr = requiredOrdering - .map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns)) + .map(SortOrder(_, Ascending)) + .map(BindReferences.bindReference(_, outputSpec.outputColumns)) SortExec( orderingExpr, global = false, - child = queryExecution.executedPlan).execute() + child = plan).execute() } val ret = new Array[WriteTaskResult](rdd.partitions.length) sparkSession.sparkContext.runJob( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 8731ee88f87f..835ce9846247 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 675bee85bf61..dd7ef0d15c14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,7 +27,9 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.util.SchemaUtils /** @@ -52,11 +54,12 @@ case class InsertIntoHadoopFsRelationCommand( query: LogicalPlan, mode: SaveMode, catalogTable: Option[CatalogTable], - fileIndex: Option[FileIndex]) + fileIndex: Option[FileIndex], + outputColumns: Seq[Attribute]) extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that 
SchemaUtils.checkSchemaColumnNameDuplication( query.schema, @@ -87,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand( } val pathExists = fs.exists(qualifiedOutputPath) - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) + + val enableDynamicOverwrite = + sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = outputPath.toString) + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => @@ -101,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.Overwrite, true) => if (ifPartitionNotExists && matchingPartitions.nonEmpty) { false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true } else { deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) true @@ -124,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(sparkSession) } - if (mode == SaveMode.Overwrite) { + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( @@ -139,11 +153,11 @@ case class InsertIntoHadoopFsRelationCommand( val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, + plan = child, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec( - qualifiedOutputPath.toString, customPartitionLocations), + qualifiedOutputPath.toString, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 40825a1f724b..39c594a9bc61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual * Hadoop output committer using an option specified in SQLConf. 
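A usage sketch for the dynamic partition overwrite mode wired up in the InsertIntoHadoopFsRelationCommand hunk above. The app name, data, and output path are made up for illustration; the default mode stays "static", which keeps the old behavior of deleting every matching partition before the write:

import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("dynamic-overwrite-sketch").master("local[*]").getOrCreate()
import spark.implicits._

// With dynamic mode, only the partitions that actually receive new rows are rewritten.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

Seq((1, "2018-01-01"), (2, "2018-01-02")).toDF("id", "dt")
  .write
  .mode(SaveMode.Overwrite)
  .partitionBy("dt")
  .parquet("/tmp/events")  // partitions dt=2018-01-01 and dt=2018-01-02 are replaced; others are untouched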
*/ -class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String) - extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { +class SQLHadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) + extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + with Serializable with Logging { override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + var committer = super.setupCommitter(context) val configuration = context.getConfiguration val clazz = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index b64d71bb4eef..a585cbed2551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -150,7 +150,7 @@ private[csv] object CSVInferSchema { if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { TimestampType } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { - // We keep this for backwords competibility. + // We keep this for backwards compatibility. TimestampType } else { tryParseBoolean(field, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index a13a5a34b4a8..c16790630ce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -89,6 +89,14 @@ class CSVOptions( val quote = getChar("quote", '\"') val escape = getChar("escape", '\\') + val charToEscapeQuoteEscaping = parameters.get("charToEscapeQuoteEscaping") match { + case None => None + case Some(null) => None + case Some(value) if value.length == 0 => None + case Some(value) if value.length == 1 => Some(value.charAt(0)) + case _ => + throw new RuntimeException("charToEscapeQuoteEscaping cannot be more than one character") + } val comment = getChar("comment", '\u0000') val headerFlag = getBool("header") @@ -148,6 +156,7 @@ class CSVOptions( format.setDelimiter(delimiter) format.setQuote(quote) format.setQuoteEscape(escape) + charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) @@ -165,6 +174,7 @@ class CSVOptions( format.setDelimiter(delimiter) format.setQuote(quote) format.setQuoteEscape(escape) + charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de54..ef67ea7d17ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. + val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index e4fca1b10dfa..49c506bc560c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 0c1708131ae4..df034adf1e7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -40,12 +40,8 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - // Non-deterministic expressions are stateful and we must keep the input sequence unchanged - // to avoid changing the result. This means, we can't evaluate the filter conditions that - // are after the first non-deterministic condition ahead. Here we only try to push down - // deterministic conditions that are before the first non-deterministic condition. 
- val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val stayUpFilters: Seq[Expression] = reader match { case r: SupportsPushDownCatalystFilters => @@ -74,7 +70,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel case _ => candidates } - val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And) + val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) if (withFilter.output == fields) { withFilter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 1862da8892cb..f0bdf84bb7a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index ee763e23415c..1918fcc5482d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -139,7 +139,7 @@ case class BroadcastHashJoinExec( // At the end of the task, we update the avg hash probe. val avgHashProbe = metricTerm(ctx, "avgHashProbe") - // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val relationTerm = ctx.addMutableState(clsName, "relation", v => s""" | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index d98cf852a1b4..1465346eb802 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -368,7 +368,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap // The minimum key private var minKey = Long.MaxValue - // The maxinum key + // The maximum key private var maxKey = Long.MinValue // The array to store the key and offset of UnsafeRow in the page. 
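Stepping back to the ParquetOptions hunk above: it establishes an explicit precedence among the three ways to choose a Parquet codec. A usage sketch, assuming an active SparkSession `spark`; the output path is made up:

// Session-wide default, lowest precedence.
spark.conf.set("spark.sql.parquet.compression.codec", "snappy")

spark.range(100).toDF("id")
  .write
  .option("parquet.compression", "gzip")  // overrides the session default
  .option("compression", "none")          // highest precedence; this write is uncompressed
  .parquet("/tmp/parquet-codec-sketch")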
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 073730462a75..94405410cce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -422,7 +422,7 @@ case class SortMergeJoinExec( */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. - // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) @@ -440,8 +440,9 @@ case class SortMergeJoinExec( val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold + // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", - v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);") + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -576,7 +577,7 @@ case class SortMergeJoinExec( override def needCopyResult: Boolean = true override def doProduce(ctx: CodegenContext): String = { - // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", v => s"$v = inputs[0];", forceInline = true) val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5cc8ed353565..dc5ba96e69ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -30,8 +30,8 @@ import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index f5a4cbc4793e..2f53fe788c7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -202,12 +202,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def trySplitFilter(plan: SparkPlan): SparkPlan = { plan match { case filter: FilterExec => - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(filter.condition).span(_.deterministic) + val (candidates, nonDeterministic) = + 
splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) if (pushDown.nonEmpty) { val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) - FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild) + FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) } else { filter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index ef27fbc2db7d..d3f743d9eb61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,9 +29,12 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - evalType: Int) + evalType: Int, + udfDeterministic: Boolean) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override def toString: String = s"$name(${children.mkString(", ")})" override def nullable: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 348e49e473ed..50dca32cb786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -29,10 +29,11 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - pythonEvalType: Int) { + pythonEvalType: Int, + udfDeterministic: Boolean) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, pythonEvalType) + PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 6bd069662200..2715fa93d0e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -118,13 +118,14 @@ class FileStreamSink( throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}") } } + val qe = data.queryExecution FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = data.queryExecution, + plan = qe.executedPlan, fileFormat = fileFormat, committer = committer, - outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), + outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 20f9810faa5c..9a7a13fcc580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 3f85fa913f28..d02cf882b61a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamR import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 65d6d1893616..261d69bbd984 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 case 
class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) - extends v2.reader.Offset { + extends v2.streaming.reader.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 0ca2e7854d94..a9d50e3a112e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 167e991ca62f..4aba76cad367 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -72,8 +72,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { * left AND right AND joined is equivalent to full. * * Note that left and right do not necessarily contain *all* conjuncts which satisfy - * their condition. Any conjuncts after the first nondeterministic one are treated as - * nondeterministic for purposes of the split. + * their condition. * * @param leftSideOnly Deterministic conjuncts which reference only the left side of the join. * @param rightSideOnly Deterministic conjuncts which reference only the right side of the join. @@ -111,7 +110,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { // Span rather than partition, because nondeterministic expressions don't commute // across AND. val (deterministicConjuncts, nonDeterministicConjuncts) = - splitConjunctivePredicates(condition.get).span(_.deterministic) + splitConjunctivePredicates(condition.get).partition(_.deterministic) val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond => cond.references.subsetOf(left.outputSet) @@ -204,7 +203,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already - * loaded. This is class is a modified verion of [[ZippedPartitionsRDD2]]. + * loaded. This is class is a modified version of [[ZippedPartitionsRDD2]]. 
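The span-to-partition change recurs in the data source push-down rule, the Python UDF filter splitter, and the streaming join helper above. A standalone illustration of the difference: `span` stops at the first non-deterministic conjunct, so deterministic predicates that follow it were never pushed down, while `partition` collects all of them:

// Toy predicate type; not a Catalyst expression.
case class Pred(sql: String, deterministic: Boolean)

val conjuncts = Seq(
  Pred("a > 1", deterministic = true),
  Pred("rand() < 0.5", deterministic = false),
  Pred("b = 2", deterministic = true))

val (pushableBySpan, _)      = conjuncts.span(_.deterministic)       // Seq(a > 1)
val (pushableByPartition, _) = conjuncts.partition(_.deterministic)  // Seq(a > 1, b = 2)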
*/ class StateStoreAwareZipPartitionsRDD[A: ClassTag, B: ClassTag, V: ClassTag]( sc: SparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 271bc4da99c0..19e3e55cb282 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.Trigger /** - * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * A [[Trigger]] that processes only one batch of data in a streaming query then terminates * the query. */ @Experimental diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 89fb2ace2091..d79e4bd65f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, Ro import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset} import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.util.{SystemClock, ThreadUtils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 1c35b06bd4b8..2843ab13bde2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,9 +29,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, ContinuousWriteSupport, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 89a8562b4b59..c9aa78a5a2e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -27,8 +27,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} case class ContinuousRateStreamPartitionOffset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 7f1e8abd79b9..98017c3ac6a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -26,8 +26,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.{ContinuousWriter, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 1c66aed8690a..97bada08bcd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} import org.apache.spark.util.SystemClock diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 
972248d5e4df..da7c31cf6242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{ContinuousWriteSupport, DataSourceV2, DataSourceV2Options, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2295b8dd5fe3..d8adbe7bee13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -175,7 +175,7 @@ class SQLAppStatusListener( // Check the execution again for whether the aggregated metrics data has been calculated. // This can happen if the UI is requesting this data, and the onExecutionEnd handler is - // running at the same time. The metrics calculcated for the UI can be innacurate in that + // running at the same time. The metrics calculated for the UI can be innacurate in that // case, since the onExecutionEnd handler will clean up tracked stage metrics. if (exec.metricsValues != null) { exec.metricsValues diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 058c38c8cb8f..1e076207bc60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -86,7 +86,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable { def bufferEncoder: Encoder[BUF] /** - * Specifies the `Encoder` for the final ouput value type. + * Specifies the `Encoder` for the final output value type. 
* @since 2.0.0 */ def outputEncoder: Encoder[OUT] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03b654f83052..40a058d2cadd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 1.3.0 */ + @scala.annotation.varargs def apply(exprs: Column*): Column = { Column(ScalaUDF( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 052a3f533da5..0d11682d80a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -2171,7 +2171,8 @@ object functions { def base64(e: Column): Column = withExpr { Base64(e.expr) } /** - * Concatenates multiple input string columns together into a single string column. + * Concatenates multiple input columns together into a single column. + * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. * * @group string_funcs * @since 1.5.0 @@ -3253,42 +3254,66 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off line.size.limit // scalastyle:off parameter.number /* Use the following code to generate: - (0 to 10).map { x => + + (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" - /** - * Defines a deterministic user-defined function of ${x} arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. 
- * - * @group udf_funcs - * @since 1.3.0 - */ - def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Defines a Scala closure of $x arguments as user-defined function (UDF). + | * The data types are automatically inferred based on the Scala closure's + | * signature. By default the returned UDF is deterministic. To change it to + | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 1.3.0 + | */ + |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | val udf = UserDefinedFunction(f, dataType, inputTypes) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + val funcCall = if (i == 0) "() => func" else "func" + println(s""" + |/** + | * Defines a Java UDF$i instance as user-defined function (UDF). + | * The caller must specify the output data type, and there is no automatic input type coercion. + | * By default the returned UDF is deterministic. To change it to nondeterministic, call the + | * API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 2.3.0 + | */ + |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { + | val func = f$anyCast.call($anyParams) + | UserDefinedFunction($funcCall, returnType, inputTypes = None) + |}""".stripMargin) } */ + ////////////////////////////////////////////////////////////////////////////////////////////// + // Scala UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Defines a deterministic user-defined function of 0 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3301,10 +3326,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 1 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3317,10 +3342,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 2 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3333,10 +3358,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 3 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3349,10 +3374,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 4 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3365,10 +3390,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 5 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3381,10 +3406,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 6 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3397,10 +3422,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 7 arguments as user-defined - * function (UDF). 
The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3413,10 +3438,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 8 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3429,10 +3454,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 9 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3445,10 +3470,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 10 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3460,13 +3485,172 @@ object functions { if (nullable) udf else udf.asNonNullable() } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF0[Any]].call() + UserDefinedFunction(() => func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). 
+ * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + // scalastyle:on parameter.number // scalastyle:on line.size.limit /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
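A small sketch of the behavior the regenerated scaladoc describes, assuming an active SparkSession `spark`: UDFs built with `udf` are deterministic unless marked otherwise, and marking a stateful or random closure nondeterministic keeps the optimizer from constant-folding or reordering its evaluation:

import org.apache.spark.sql.functions.{col, udf}

val plusOne = udf((x: Int) => x + 1)                          // deterministic by default
val noise   = udf(() => scala.util.Random.nextDouble).asNondeterministic()

val result = spark.range(5).toDF("id")
  .withColumn("plus_one", plusOne(col("id").cast("int")))
  .withColumn("noise", noise())                               // evaluated once per row, not constant-folded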
* * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f17935e86f45..2e92beecf2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -261,17 +262,20 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
          *
   * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
   * considered in every trigger.</li>
-  * <li>`sep` (default `,`): sets the single character as a separator for each
+  * <li>`sep` (default `,`): sets a single character as a separator for each
   * field and value.</li>
   * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
   * type.</li>
-  * <li>`quote` (default `"`): sets the single character used for escaping quoted values where
+  * <li>`quote` (default `"`): sets a single character used for escaping quoted values where
   * the separator can be part of the value. If you would like to turn off quotations, you need to
   * set not `null` but an empty string. This behaviour is different form
   * `com.databricks.spark.csv`.</li>
-  * <li>`escape` (default `\`): sets the single character used for escaping quotes inside
+  * <li>`escape` (default `\`): sets a single character used for escaping quotes inside
   * an already quoted value.</li>
-  * <li>`comment` (default empty string): sets the single character used for skipping lines
+  * <li>`charToEscapeQuoteEscaping` (default `escape` or `\0`): sets a single character used for
+  * escaping the escape for the quote character. The default value is escape character when escape
+  * and quote characters are different, `\0` otherwise.</li>
+  * <li>`comment` (default empty string): sets a single character used for skipping lines
   * beginning with this character. By default, it is disabled.</li>
   * <li>`header` (default `false`): uses the first line as names of columns.</li>
        • `inferSchema` (default `false`): infers the input schema automatically from data. It diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index e808ffaa9641..b508f4406138 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index cedc1dce4a70..0dcb666e2c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -152,7 +152,7 @@ class StreamingQueryProgress private[sql]( * @param endOffset The ending offset for data being read. * @param numInputRows The number of records read from this source. * @param inputRowsPerSecond The rate at which data is arriving from this source. - * @param processedRowsPerSecond The rate at which data from this source is being procressed by + * @param processedRowsPerSecond The rate at which data from this source is being processed by * Spark. 
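As a loose, end-to-end sketch of the streaming CSV options documented above together with the per-source progress rates mentioned here (the path, schema, and option values are assumptions for illustration only):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{LongType, StringType, StructType}

val spark = SparkSession.builder().appName("csv-stream-sketch").master("local[2]").getOrCreate()
val schema = new StructType().add("id", LongType).add("name", StringType)

val csvStream = spark.readStream
  .schema(schema)                              // streaming CSV sources need an explicit schema
  .option("sep", ";")                          // single-character field separator
  .option("escape", "\\")                      // escapes quotes inside an already quoted value
  .option("charToEscapeQuoteEscaping", "\\")   // escapes the escape of the quote character
  .option("maxFilesPerTrigger", "10")          // cap on new files considered in every trigger
  .csv("/tmp/stream-input")                    // assumed input directory

val query = csvStream.writeStream.format("console").start()
// After a few batches, query.lastProgress.sources exposes inputRowsPerSecond and
// processedRowsPerSecond for this source.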
* @since 2.1.0 */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84..4f8a31f18572 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.expressions.UserDefinedFunction; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.BloomFilter; @@ -455,4 +456,14 @@ public void testCircularReferenceBean() { CircularReference1Bean bean = new CircularReference1Bean(); spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); } + + @Test + public void testUDF() { + UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); + Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] expected = spark.table("testData").collectAsList().stream() + .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); + Assert.assertArrayEquals(expected, result); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java index 447a71d284fb..288f5e7426c0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java @@ -47,7 +47,7 @@ public MyDoubleAvg() { _inputDataType = DataTypes.createStructType(inputFields); // The buffer has two values, bufferSum for storing the current sum and - // bufferCount for storing the number of non-null input values that have been contribuetd + // bufferCount for storing the number of non-null input values that have been contributed // to the current sum. 
List bufferFields = new ArrayList<>(); bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 40d0c064f5c4..4113734e1707 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -24,3 +24,26 @@ select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null); select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); + +-- turn off concatBinaryAsString +set spark.sql.function.concatBinaryAsString=false; + +-- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false +EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +EXPLAIN SELECT (col1 || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql new file mode 100644 index 000000000000..0beebec5702f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -0,0 +1,93 @@ +-- Concatenate mixed inputs (output type is string) +SELECT (col1 || col2 || col3) col +FROM ( + SELECT + id col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3 + FROM range(10) +); + +SELECT ((col1 || col2) || (col3 || col4) || col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on concatBinaryAsString +set spark.sql.function.concatBinaryAsString=true; + +SELECT (col1 || col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn off concatBinaryAsString +set spark.sql.function.concatBinaryAsString=false; + +-- Concatenate binary inputs (output type is binary) +SELECT (col1 || col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + 
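Roughly what the new concat tests exercise, assuming an active `SparkSession` named `spark`: with `spark.sql.function.concatBinaryAsString` set to false, `||` over all-binary inputs keeps a binary result type, while enabling the flag concatenates the inputs as strings.

spark.conf.set("spark.sql.function.concatBinaryAsString", "false")
val asBinary = spark.sql(
  "SELECT encode(string(id), 'utf-8') || encode(string(id + 1), 'utf-8') AS col FROM range(3)")
asBinary.printSchema()   // col: binary (all inputs are binary)

spark.conf.set("spark.sql.function.concatBinaryAsString", "true")
val asString = spark.sql(
  "SELECT encode(string(id), 'utf-8') || encode(string(id + 1), 'utf-8') AS col FROM range(3)")
asString.printSchema()   // col: string (binary inputs are concatenated as a string)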
+SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql new file mode 100644 index 000000000000..1e9822186796 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql @@ -0,0 +1,60 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +select cast(1 as tinyint) + interval 2 day; +select cast(1 as smallint) + interval 2 day; +select cast(1 as int) + interval 2 day; +select cast(1 as bigint) + interval 2 day; +select cast(1 as float) + interval 2 day; +select cast(1 as double) + interval 2 day; +select cast(1 as decimal(10, 0)) + interval 2 day; +select cast('2017-12-11' as string) + interval 2 day; +select cast('2017-12-11 09:30:00' as string) + interval 2 day; +select cast('1' as binary) + interval 2 day; +select cast(1 as boolean) + interval 2 day; +select cast('2017-12-11 09:30:00.0' as timestamp) + interval 2 day; +select cast('2017-12-11 09:30:00' as date) + interval 2 day; + +select interval 2 day + cast(1 as tinyint); +select interval 2 day + cast(1 as smallint); +select interval 2 day + cast(1 as int); +select interval 2 day + cast(1 as bigint); +select interval 2 day + cast(1 as float); +select interval 2 day + cast(1 as double); +select interval 2 day + cast(1 as decimal(10, 0)); +select interval 2 day + cast('2017-12-11' as string); +select interval 2 day + cast('2017-12-11 09:30:00' as string); +select interval 2 day + cast('1' as binary); +select interval 2 day + cast(1 as boolean); +select interval 2 day + cast('2017-12-11 09:30:00.0' as timestamp); +select interval 2 day + cast('2017-12-11 09:30:00' as date); + +select cast(1 as tinyint) - interval 2 day; +select cast(1 as smallint) - interval 2 day; +select cast(1 as int) - interval 2 day; +select cast(1 as bigint) - interval 2 day; +select cast(1 as float) - interval 2 day; +select cast(1 as double) - interval 2 day; +select cast(1 as decimal(10, 0)) - interval 2 day; +select cast('2017-12-11' as string) - interval 2 day; +select cast('2017-12-11 09:30:00' as string) - interval 2 day; +select cast('1' as binary) - interval 2 day; +select cast(1 as boolean) - interval 2 day; +select cast('2017-12-11 09:30:00.0' as timestamp) - interval 2 day; +select cast('2017-12-11 09:30:00' as date) - interval 2 day; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql 
b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql new file mode 100644 index 000000000000..c8e108ac2c45 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -0,0 +1,33 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b; + +-- division, remainder and pmod by 0 return NULL +select a / b from t; +select a % b from t; +select pmod(a, b) from t; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss return NULL +select 123456789123456789.1234567890 * 1.123456789123456789; +select 0.001 / 9876543210987654321098765432109876543.2 diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql new file mode 100644 index 000000000000..717616f91db0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql @@ -0,0 +1,44 @@ +-- Mixed inputs (output type is string) +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on eltOutputAsString +set spark.sql.function.eltOutputAsString=true; + +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +-- turn off eltOutputAsString +set spark.sql.function.eltOutputAsString=false; + +-- Elt binary inputs (output type is binary) +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql index 58866f4b1811..6de22b8b7c3d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql @@ -32,7 +32,7 @@ SELECT 1.1 - '2.2' FROM t; SELECT 1.1 * '2.2' FROM t; SELECT 4.4 / '2.2' FROM t; --- concatentation +-- concatenation SELECT '$' || cast(1 as smallint) || '$' FROM t; SELECT '$' || 1 || '$' 
FROM t; SELECT '$' || cast(1 as bigint) || '$' FROM t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql new file mode 100644 index 000000000000..f17adb56dee9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql @@ -0,0 +1,57 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 'aa' as a; + +-- casting to data types which are unable to represent the string input returns NULL +select cast(a as byte) from t; +select cast(a as short) from t; +select cast(a as int) from t; +select cast(a as long) from t; +select cast(a as float) from t; +select cast(a as double) from t; +select cast(a as decimal) from t; +select cast(a as boolean) from t; +select cast(a as timestamp) from t; +select cast(a as date) from t; +-- casting to binary works correctly +select cast(a as binary) from t; +-- casting to array, struct or map throws exception +select cast(a as array) from t; +select cast(a as struct) from t; +select cast(a as map) from t; + +-- all timestamp/date expressions return NULL if bad input strings are provided +select to_timestamp(a) from t; +select to_timestamp('2018-01-01', a) from t; +select to_unix_timestamp(a) from t; +select to_unix_timestamp('2018-01-01', a) from t; +select unix_timestamp(a) from t; +select unix_timestamp('2018-01-01', a) from t; +select from_unixtime(a) from t; +select from_unixtime('2018-01-01', a) from t; +select next_day(a, 'MO') from t; +select next_day('2018-01-01', a) from t; +select trunc(a, 'MM') from t; +select trunc('2018-01-01', a) from t; + +-- some functions return NULL if bad input is provided +select unhex('-123'); +select sha2(a, a) from t; +select get_json_object(a, a) from t; +select json_tuple(a, a) from t; +select from_json(a, 'a INT') from t; diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 2d9b3d7d2ca3..d5f8705a35ed 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 15 -- !query 0 @@ -118,3 +118,46 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') struct -- !query 11 output NULL NULL + + +-- !query 12 +set spark.sql.function.concatBinaryAsString=false +-- !query 12 schema +struct +-- !query 12 output +spark.sql.function.concatBinaryAsString false + + +-- !query 13 +EXPLAIN SELECT ((col1 || col2) 
|| (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 13 schema +struct +-- !query 13 output +== Physical Plan == +*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] ++- *Range (0, 10, step=1, splits=2) + + +-- !query 14 +EXPLAIN SELECT (col1 || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 14 schema +struct +-- !query 14 output +== Physical Plan == +*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] ++- *Range (0, 10, step=1, splits=2) diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out index fe7bde040707..2914d6015ea8 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out @@ -16,7 +16,7 @@ SELECT cast(1 as binary) = '1' FROM t struct<> -- !query 1 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 2 @@ -25,7 +25,7 @@ SELECT cast(1 as binary) > '2' FROM t struct<> -- !query 2 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 3 @@ -34,7 +34,7 @@ SELECT cast(1 as binary) >= '2' FROM t struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 4 @@ -43,7 +43,7 @@ SELECT cast(1 as binary) < '2' FROM t struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 5 @@ -52,7 +52,7 @@ SELECT cast(1 as binary) <= '2' FROM t struct<> -- !query 5 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 6 @@ -61,7 +61,7 @@ SELECT cast(1 as binary) <> '2' FROM t struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 7 @@ -70,7 +70,7 @@ SELECT cast(1 as 
binary) = cast(null as string) FROM t struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 8 @@ -79,7 +79,7 @@ SELECT cast(1 as binary) > cast(null as string) FROM t struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 9 @@ -88,7 +88,7 @@ SELECT cast(1 as binary) >= cast(null as string) FROM t struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 10 @@ -97,7 +97,7 @@ SELECT cast(1 as binary) < cast(null as string) FROM t struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 11 @@ -106,7 +106,7 @@ SELECT cast(1 as binary) <= cast(null as string) FROM t struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 12 @@ -115,7 +115,7 @@ SELECT cast(1 as binary) <> cast(null as string) FROM t struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 13 @@ -124,7 +124,7 @@ SELECT '1' = cast(1 as binary) FROM t struct<> -- !query 13 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 14 @@ -133,7 +133,7 @@ SELECT '2' > cast(1 as binary) FROM t struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 15 @@ -142,7 +142,7 @@ SELECT '2' >= cast(1 as binary) FROM t struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 16 @@ -151,7 +151,7 @@ SELECT '2' < cast(1 as binary) FROM t struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' 
due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 17 @@ -160,7 +160,7 @@ SELECT '2' <= cast(1 as binary) FROM t struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 18 @@ -169,7 +169,7 @@ SELECT '2' <> cast(1 as binary) FROM t struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 19 @@ -178,7 +178,7 @@ SELECT cast(null as string) = cast(1 as binary) FROM t struct<> -- !query 19 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 20 @@ -187,7 +187,7 @@ SELECT cast(null as string) > cast(1 as binary) FROM t struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 21 @@ -196,7 +196,7 @@ SELECT cast(null as string) >= cast(1 as binary) FROM t struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31 -- !query 22 @@ -205,7 +205,7 @@ SELECT cast(null as string) < cast(1 as binary) FROM t struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 23 @@ -214,7 +214,7 @@ SELECT cast(null as string) <= cast(1 as binary) FROM t struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31 -- !query 24 @@ -223,7 +223,7 @@ SELECT cast(null as string) <> cast(1 as binary) FROM t struct<> -- !query 24 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31 -- !query 25 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out new file mode 100644 index 000000000000..09729fdc2ec3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -0,0 +1,239 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +SELECT (col1 || col2 || col3) col +FROM ( + SELECT 
+ id col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +012 +123 +234 +345 +456 +567 +678 +789 +8910 +91011 + + +-- !query 1 +SELECT ((col1 || col2) || (col3 || col4) || col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +prefix_0120.0 +prefix_1231.0 +prefix_2342.0 +prefix_3453.0 +prefix_4564.0 +prefix_5675.0 +prefix_6786.0 +prefix_7897.0 +prefix_89108.0 +prefix_910119.0 + + +-- !query 2 +SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 2 schema +struct +-- !query 2 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 + + +-- !query 3 +set spark.sql.function.concatBinaryAsString=true +-- !query 3 schema +struct +-- !query 3 output +spark.sql.function.concatBinaryAsString true + + +-- !query 4 +SELECT (col1 || col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 4 schema +struct +-- !query 4 output +01 +12 +23 +34 +45 +56 +67 +78 +89 +910 + + +-- !query 5 +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 + + +-- !query 6 +SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 6 schema +struct +-- !query 6 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 + + +-- !query 7 +set spark.sql.function.concatBinaryAsString=false +-- !query 7 schema +struct +-- !query 7 output +spark.sql.function.concatBinaryAsString false + + +-- !query 8 +SELECT (col1 || col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 8 schema +struct +-- !query 8 output +01 +12 +23 +34 +45 +56 +67 +78 +89 +910 + + +-- !query 9 +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 9 schema +struct +-- !query 9 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 + + +-- !query 10 +SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 10 schema +struct +-- !query 10 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out new file mode 100644 index 000000000000..12c1d1617679 --- /dev/null +++ 
b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out @@ -0,0 +1,349 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 40 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select cast(1 as tinyint) + interval 2 day +-- !query 1 schema +struct<> +-- !query 1 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) + interval 2 days)' (tinyint and calendarinterval).; line 1 pos 7 + + +-- !query 2 +select cast(1 as smallint) + interval 2 day +-- !query 2 schema +struct<> +-- !query 2 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) + interval 2 days)' (smallint and calendarinterval).; line 1 pos 7 + + +-- !query 3 +select cast(1 as int) + interval 2 day +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS INT) + interval 2 days)' (int and calendarinterval).; line 1 pos 7 + + +-- !query 4 +select cast(1 as bigint) + interval 2 day +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) + interval 2 days)' (bigint and calendarinterval).; line 1 pos 7 + + +-- !query 5 +select cast(1 as float) + interval 2 day +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) + interval 2 days)' (float and calendarinterval).; line 1 pos 7 + + +-- !query 6 +select cast(1 as double) + interval 2 day +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) + interval 2 days)' (double and calendarinterval).; line 1 pos 7 + + +-- !query 7 +select cast(1 as decimal(10, 0)) + interval 2 day +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + interval 2 days)' (decimal(10,0) and calendarinterval).; line 1 pos 7 + + +-- !query 8 +select cast('2017-12-11' as string) + interval 2 day +-- !query 8 schema +struct +-- !query 8 output +2017-12-13 00:00:00 + + +-- !query 9 +select cast('2017-12-11 09:30:00' as string) + interval 2 day +-- !query 9 schema +struct +-- !query 9 output +2017-12-13 09:30:00 + + +-- !query 10 +select cast('1' as binary) + interval 2 day +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) + interval 2 days)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + interval 2 days)' (binary and calendarinterval).; line 1 pos 7 + + +-- !query 11 +select cast(1 as boolean) + interval 2 day +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) + interval 2 days)' due to data type mismatch: 
differing types in '(CAST(1 AS BOOLEAN) + interval 2 days)' (boolean and calendarinterval).; line 1 pos 7 + + +-- !query 12 +select cast('2017-12-11 09:30:00.0' as timestamp) + interval 2 day +-- !query 12 schema +struct +-- !query 12 output +2017-12-13 09:30:00 + + +-- !query 13 +select cast('2017-12-11 09:30:00' as date) + interval 2 day +-- !query 13 schema +struct +-- !query 13 output +2017-12-13 + + +-- !query 14 +select interval 2 day + cast(1 as tinyint) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS TINYINT))' (calendarinterval and tinyint).; line 1 pos 7 + + +-- !query 15 +select interval 2 day + cast(1 as smallint) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS SMALLINT))' (calendarinterval and smallint).; line 1 pos 7 + + +-- !query 16 +select interval 2 day + cast(1 as int) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS INT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS INT))' (calendarinterval and int).; line 1 pos 7 + + +-- !query 17 +select interval 2 day + cast(1 as bigint) +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS BIGINT))' (calendarinterval and bigint).; line 1 pos 7 + + +-- !query 18 +select interval 2 day + cast(1 as float) +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS FLOAT))' (calendarinterval and float).; line 1 pos 7 + + +-- !query 19 +select interval 2 day + cast(1 as double) +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS DOUBLE))' (calendarinterval and double).; line 1 pos 7 + + +-- !query 20 +select interval 2 day + cast(1 as decimal(10, 0)) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS DECIMAL(10,0)))' (calendarinterval and decimal(10,0)).; line 1 pos 7 + + +-- !query 21 +select interval 2 day + cast('2017-12-11' as string) +-- !query 21 schema +struct +-- !query 21 output +2017-12-13 00:00:00 + + +-- !query 22 +select interval 2 day + cast('2017-12-11 09:30:00' as string) +-- !query 22 schema +struct +-- !query 22 output +2017-12-13 09:30:00 + + +-- !query 23 +select interval 2 day + cast('1' as binary) +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(interval 2 days + CAST('1' AS BINARY))' (calendarinterval and binary).; line 1 pos 7 + + +-- !query 24 +select interval 2 day + cast(1 as boolean) +-- !query 24 schema +struct<> +-- !query 
24 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS BOOLEAN))' (calendarinterval and boolean).; line 1 pos 7 + + +-- !query 25 +select interval 2 day + cast('2017-12-11 09:30:00.0' as timestamp) +-- !query 25 schema +struct +-- !query 25 output +2017-12-13 09:30:00 + + +-- !query 26 +select interval 2 day + cast('2017-12-11 09:30:00' as date) +-- !query 26 schema +struct +-- !query 26 output +2017-12-13 + + +-- !query 27 +select cast(1 as tinyint) - interval 2 day +-- !query 27 schema +struct<> +-- !query 27 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) - interval 2 days)' (tinyint and calendarinterval).; line 1 pos 7 + + +-- !query 28 +select cast(1 as smallint) - interval 2 day +-- !query 28 schema +struct<> +-- !query 28 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) - interval 2 days)' (smallint and calendarinterval).; line 1 pos 7 + + +-- !query 29 +select cast(1 as int) - interval 2 day +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS INT) - interval 2 days)' (int and calendarinterval).; line 1 pos 7 + + +-- !query 30 +select cast(1 as bigint) - interval 2 day +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) - interval 2 days)' (bigint and calendarinterval).; line 1 pos 7 + + +-- !query 31 +select cast(1 as float) - interval 2 day +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) - interval 2 days)' (float and calendarinterval).; line 1 pos 7 + + +-- !query 32 +select cast(1 as double) - interval 2 day +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) - interval 2 days)' (double and calendarinterval).; line 1 pos 7 + + +-- !query 33 +select cast(1 as decimal(10, 0)) - interval 2 day +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - interval 2 days)' (decimal(10,0) and calendarinterval).; line 1 pos 7 + + +-- !query 34 +select cast('2017-12-11' as string) - interval 2 day +-- !query 34 schema +struct +-- !query 34 output +2017-12-09 00:00:00 + + +-- !query 35 +select cast('2017-12-11 09:30:00' as string) - interval 2 day +-- !query 35 schema +struct +-- !query 35 output +2017-12-09 09:30:00 + + +-- !query 36 +select cast('1' as binary) - interval 2 day +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) - interval 2 days)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - interval 2 days)' (binary and 
calendarinterval).; line 1 pos 7 + + +-- !query 37 +select cast(1 as boolean) - interval 2 day +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) - interval 2 days)' (boolean and calendarinterval).; line 1 pos 7 + + +-- !query 38 +select cast('2017-12-11 09:30:00.0' as timestamp) - interval 2 day +-- !query 38 schema +struct +-- !query 38 output +2017-12-09 09:30:00 + + +-- !query 39 +select cast('2017-12-11 09:30:00' as date) - interval 2 day +-- !query 39 schema +struct +-- !query 39 output +2017-12-09 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out new file mode 100644 index 000000000000..ce02f6adc456 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -0,0 +1,82 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select a / b from t +-- !query 1 schema +struct<(CAST(a AS DECIMAL(2,1)) / CAST(b AS DECIMAL(2,1))):decimal(8,6)> +-- !query 1 output +NULL + + +-- !query 2 +select a % b from t +-- !query 2 schema +struct<(CAST(a AS DECIMAL(2,1)) % CAST(b AS DECIMAL(2,1))):decimal(1,1)> +-- !query 2 output +NULL + + +-- !query 3 +select pmod(a, b) from t +-- !query 3 schema +struct +-- !query 3 output +NULL + + +-- !query 4 +select (5e36 + 0.1) + 5e36 +-- !query 4 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 4 output +NULL + + +-- !query 5 +select (-4e36 - 0.1) - 7e36 +-- !query 5 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 5 output +NULL + + +-- !query 6 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 6 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 6 output +NULL + + +-- !query 7 +select 1e35 / 0.1 +-- !query 7 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +-- !query 7 output +NULL + + +-- !query 8 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 8 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 8 output +NULL + + +-- !query 9 +select 0.001 / 9876543210987654321098765432109876543.2 +-- !query 9 schema +struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)> +-- !query 9 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out new file mode 100644 index 000000000000..b62e1b682604 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out @@ -0,0 +1,115 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) 
col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 1 +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +10 +11 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 2 +set spark.sql.function.eltOutputAsString=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.function.eltOutputAsString true + + +-- !query 3 +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 4 +set spark.sql.function.eltOutputAsString=false +-- !query 4 schema +struct +-- !query 4 output +spark.sql.function.eltOutputAsString false + + +-- !query 5 +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out index bf8ddee89b79..875ccc1341ec 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out @@ -80,7 +80,7 @@ SELECT cast(1 as tinyint) in (cast('1' as binary)) FROM t struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ByteType != BinaryType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: tinyint != binary; line 1 pos 26 -- !query 10 @@ -89,7 +89,7 @@ SELECT cast(1 as tinyint) in (cast(1 as boolean)) FROM t struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ByteType != BooleanType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: tinyint != boolean; line 1 pos 26 -- !query 11 @@ -98,7 +98,7 @@ SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ByteType != TimestampType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: tinyint != timestamp; line 1 pos 26 -- !query 12 @@ -107,7 +107,7 @@ SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ByteType != DateType; line 1 pos 26 +cannot 
resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: tinyint != date; line 1 pos 26 -- !query 13 @@ -180,7 +180,7 @@ SELECT cast(1 as smallint) in (cast('1' as binary)) FROM t struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ShortType != BinaryType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: smallint != binary; line 1 pos 27 -- !query 22 @@ -189,7 +189,7 @@ SELECT cast(1 as smallint) in (cast(1 as boolean)) FROM t struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ShortType != BooleanType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: smallint != boolean; line 1 pos 27 -- !query 23 @@ -198,7 +198,7 @@ SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ShortType != TimestampType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: smallint != timestamp; line 1 pos 27 -- !query 24 @@ -207,7 +207,7 @@ SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 24 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ShortType != DateType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: smallint != date; line 1 pos 27 -- !query 25 @@ -280,7 +280,7 @@ SELECT cast(1 as int) in (cast('1' as binary)) FROM t struct<> -- !query 33 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BinaryType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: int != binary; line 1 pos 22 -- !query 34 @@ -289,7 +289,7 @@ SELECT cast(1 as int) in (cast(1 as boolean)) FROM t struct<> -- !query 34 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BooleanType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: int != boolean; line 1 pos 22 -- !query 35 @@ -298,7 +298,7 @@ SELECT cast(1 as int) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 35 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: IntegerType != 
TimestampType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: int != timestamp; line 1 pos 22 -- !query 36 @@ -307,7 +307,7 @@ SELECT cast(1 as int) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 36 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: IntegerType != DateType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: int != date; line 1 pos 22 -- !query 37 @@ -380,7 +380,7 @@ SELECT cast(1 as bigint) in (cast('1' as binary)) FROM t struct<> -- !query 45 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: LongType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: bigint != binary; line 1 pos 25 -- !query 46 @@ -389,7 +389,7 @@ SELECT cast(1 as bigint) in (cast(1 as boolean)) FROM t struct<> -- !query 46 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: LongType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: bigint != boolean; line 1 pos 25 -- !query 47 @@ -398,7 +398,7 @@ SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 47 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: LongType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: bigint != timestamp; line 1 pos 25 -- !query 48 @@ -407,7 +407,7 @@ SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 48 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: LongType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: bigint != date; line 1 pos 25 -- !query 49 @@ -480,7 +480,7 @@ SELECT cast(1 as float) in (cast('1' as binary)) FROM t struct<> -- !query 57 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: FloatType != BinaryType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: float != binary; line 1 pos 24 -- !query 58 @@ -489,7 +489,7 @@ SELECT cast(1 as float) in (cast(1 as boolean)) FROM t struct<> -- !query 58 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: FloatType 
!= BooleanType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: float != boolean; line 1 pos 24 -- !query 59 @@ -498,7 +498,7 @@ SELECT cast(1 as float) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 59 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: FloatType != TimestampType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: float != timestamp; line 1 pos 24 -- !query 60 @@ -507,7 +507,7 @@ SELECT cast(1 as float) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 60 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: FloatType != DateType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: float != date; line 1 pos 24 -- !query 61 @@ -580,7 +580,7 @@ SELECT cast(1 as double) in (cast('1' as binary)) FROM t struct<> -- !query 69 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: double != binary; line 1 pos 25 -- !query 70 @@ -589,7 +589,7 @@ SELECT cast(1 as double) in (cast(1 as boolean)) FROM t struct<> -- !query 70 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: double != boolean; line 1 pos 25 -- !query 71 @@ -598,7 +598,7 @@ SELECT cast(1 as double) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 71 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DoubleType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: double != timestamp; line 1 pos 25 -- !query 72 @@ -607,7 +607,7 @@ SELECT cast(1 as double) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 72 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DoubleType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: double != date; line 1 pos 25 -- !query 73 @@ -680,7 +680,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('1' as binary)) FROM t struct<> -- !query 81 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('1' 
AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BinaryType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != binary; line 1 pos 33 -- !query 82 @@ -689,7 +689,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as boolean)) FROM t struct<> -- !query 82 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BooleanType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != boolean; line 1 pos 33 -- !query 83 @@ -698,7 +698,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00.0' as timestamp)) struct<> -- !query 83 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != TimestampType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != timestamp; line 1 pos 33 -- !query 84 @@ -707,7 +707,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 84 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != DateType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != date; line 1 pos 33 -- !query 85 @@ -780,7 +780,7 @@ SELECT cast(1 as string) in (cast('1' as binary)) FROM t struct<> -- !query 93 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: StringType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: string != binary; line 1 pos 25 -- !query 94 @@ -789,7 +789,7 @@ SELECT cast(1 as string) in (cast(1 as boolean)) FROM t struct<> -- !query 94 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: StringType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: string != boolean; line 1 pos 25 -- !query 95 @@ -814,7 +814,7 @@ SELECT cast('1' as binary) in (cast(1 as tinyint)) FROM t struct<> -- !query 97 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ByteType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: binary != tinyint; line 1 pos 27 -- !query 98 @@ -823,7 +823,7 @@ SELECT cast('1' as binary) in (cast(1 as smallint)) FROM t struct<> -- !query 98 
output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ShortType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: binary != smallint; line 1 pos 27 -- !query 99 @@ -832,7 +832,7 @@ SELECT cast('1' as binary) in (cast(1 as int)) FROM t struct<> -- !query 99 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != IntegerType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: binary != int; line 1 pos 27 -- !query 100 @@ -841,7 +841,7 @@ SELECT cast('1' as binary) in (cast(1 as bigint)) FROM t struct<> -- !query 100 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != LongType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: binary != bigint; line 1 pos 27 -- !query 101 @@ -850,7 +850,7 @@ SELECT cast('1' as binary) in (cast(1 as float)) FROM t struct<> -- !query 101 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != FloatType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: binary != float; line 1 pos 27 -- !query 102 @@ -859,7 +859,7 @@ SELECT cast('1' as binary) in (cast(1 as double)) FROM t struct<> -- !query 102 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DoubleType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: binary != double; line 1 pos 27 -- !query 103 @@ -868,7 +868,7 @@ SELECT cast('1' as binary) in (cast(1 as decimal(10, 0))) FROM t struct<> -- !query 103 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BinaryType != DecimalType(10,0); line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: binary != decimal(10,0); line 1 pos 27 -- !query 104 @@ -877,7 +877,7 @@ SELECT cast('1' as binary) in (cast(1 as string)) FROM t struct<> -- !query 104 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BinaryType != StringType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: binary != string; line 1 pos 27 -- !query 105 @@ -894,7 +894,7 @@ SELECT cast('1' as binary) in (cast(1 as boolean)) FROM t struct<> -- !query 106 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BOOLEAN)))' due to 
data type mismatch: Arguments must be same type but were: BinaryType != BooleanType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: binary != boolean; line 1 pos 27 -- !query 107 @@ -903,7 +903,7 @@ SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM struct<> -- !query 107 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BinaryType != TimestampType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: binary != timestamp; line 1 pos 27 -- !query 108 @@ -912,7 +912,7 @@ SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 108 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DateType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: binary != date; line 1 pos 27 -- !query 109 @@ -921,7 +921,7 @@ SELECT true in (cast(1 as tinyint)) FROM t struct<> -- !query 109 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ByteType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: boolean != tinyint; line 1 pos 12 -- !query 110 @@ -930,7 +930,7 @@ SELECT true in (cast(1 as smallint)) FROM t struct<> -- !query 110 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ShortType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: boolean != smallint; line 1 pos 12 -- !query 111 @@ -939,7 +939,7 @@ SELECT true in (cast(1 as int)) FROM t struct<> -- !query 111 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != IntegerType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: boolean != int; line 1 pos 12 -- !query 112 @@ -948,7 +948,7 @@ SELECT true in (cast(1 as bigint)) FROM t struct<> -- !query 112 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != LongType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: boolean != bigint; line 1 pos 12 -- !query 113 @@ -957,7 +957,7 @@ SELECT true in (cast(1 as float)) FROM t struct<> -- !query 113 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != FloatType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: 
boolean != float; line 1 pos 12 -- !query 114 @@ -966,7 +966,7 @@ SELECT true in (cast(1 as double)) FROM t struct<> -- !query 114 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DoubleType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: boolean != double; line 1 pos 12 -- !query 115 @@ -975,7 +975,7 @@ SELECT true in (cast(1 as decimal(10, 0))) FROM t struct<> -- !query 115 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BooleanType != DecimalType(10,0); line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: boolean != decimal(10,0); line 1 pos 12 -- !query 116 @@ -984,7 +984,7 @@ SELECT true in (cast(1 as string)) FROM t struct<> -- !query 116 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BooleanType != StringType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: boolean != string; line 1 pos 12 -- !query 117 @@ -993,7 +993,7 @@ SELECT true in (cast('1' as binary)) FROM t struct<> -- !query 117 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: BooleanType != BinaryType; line 1 pos 12 +cannot resolve '(true IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: boolean != binary; line 1 pos 12 -- !query 118 @@ -1010,7 +1010,7 @@ SELECT true in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 119 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BooleanType != TimestampType; line 1 pos 12 +cannot resolve '(true IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: boolean != timestamp; line 1 pos 12 -- !query 120 @@ -1019,7 +1019,7 @@ SELECT true in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 120 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DateType; line 1 pos 12 +cannot resolve '(true IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: boolean != date; line 1 pos 12 -- !query 121 @@ -1028,7 +1028,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as tinyint)) FROM t struct<> -- !query 121 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ByteType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != tinyint; line 1 pos 50 -- !query 122 @@ -1037,7 +1037,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as 
smallint)) FROM struct<> -- !query 122 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ShortType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != smallint; line 1 pos 50 -- !query 123 @@ -1046,7 +1046,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as int)) FROM t struct<> -- !query 123 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != IntegerType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: timestamp != int; line 1 pos 50 -- !query 124 @@ -1055,7 +1055,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as bigint)) FROM t struct<> -- !query 124 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != LongType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != bigint; line 1 pos 50 -- !query 125 @@ -1064,7 +1064,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as float)) FROM t struct<> -- !query 125 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != FloatType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: timestamp != float; line 1 pos 50 -- !query 126 @@ -1073,7 +1073,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as double)) FROM t struct<> -- !query 126 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: TimestampType != DoubleType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: timestamp != double; line 1 pos 50 -- !query 127 @@ -1082,7 +1082,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as decimal(10, 0))) struct<> -- !query 127 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: TimestampType != DecimalType(10,0); line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: timestamp != decimal(10,0); line 1 pos 50 -- !query 128 @@ -1099,7 +1099,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2' as binary)) FROM struct<> -- !query 129 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN 
(CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BinaryType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: timestamp != binary; line 1 pos 50 -- !query 130 @@ -1108,7 +1108,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as boolean)) FROM t struct<> -- !query 130 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BooleanType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: timestamp != boolean; line 1 pos 50 -- !query 131 @@ -1133,7 +1133,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as tinyint)) FROM t struct<> -- !query 133 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ByteType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: date != tinyint; line 1 pos 43 -- !query 134 @@ -1142,7 +1142,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as smallint)) FROM t struct<> -- !query 134 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ShortType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: date != smallint; line 1 pos 43 -- !query 135 @@ -1151,7 +1151,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as int)) FROM t struct<> -- !query 135 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: DateType != IntegerType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: date != int; line 1 pos 43 -- !query 136 @@ -1160,7 +1160,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as bigint)) FROM t struct<> -- !query 136 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: DateType != LongType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: date != bigint; line 1 pos 43 -- !query 137 @@ -1169,7 +1169,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as float)) FROM t struct<> -- !query 137 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: DateType != FloatType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: date != float; line 1 pos 43 -- 
!query 138 @@ -1178,7 +1178,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as double)) FROM t struct<> -- !query 138 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: DateType != DoubleType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: date != double; line 1 pos 43 -- !query 139 @@ -1187,7 +1187,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as decimal(10, 0))) FROM t struct<> -- !query 139 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: DateType != DecimalType(10,0); line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: date != decimal(10,0); line 1 pos 43 -- !query 140 @@ -1204,7 +1204,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2' as binary)) FROM t struct<> -- !query 141 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DateType != BinaryType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: date != binary; line 1 pos 43 -- !query 142 @@ -1213,7 +1213,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as boolean)) FROM t struct<> -- !query 142 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DateType != BooleanType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: date != boolean; line 1 pos 43 -- !query 143 @@ -1302,7 +1302,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('1' as binary)) FROM t struct<> -- !query 153 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ByteType != BinaryType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: tinyint != binary; line 1 pos 26 -- !query 154 @@ -1311,7 +1311,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as boolean)) FROM t struct<> -- !query 154 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ByteType != BooleanType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: tinyint != boolean; line 1 pos 26 -- !query 155 @@ -1320,7 +1320,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00.0' a struct<> -- !query 155 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 
09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ByteType != TimestampType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: tinyint != timestamp; line 1 pos 26 -- !query 156 @@ -1329,7 +1329,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00' as struct<> -- !query 156 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ByteType != DateType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: tinyint != date; line 1 pos 26 -- !query 157 @@ -1402,7 +1402,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('1' as binary)) FROM t struct<> -- !query 165 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ShortType != BinaryType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: smallint != binary; line 1 pos 27 -- !query 166 @@ -1411,7 +1411,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as boolean)) FROM t struct<> -- !query 166 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ShortType != BooleanType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: smallint != boolean; line 1 pos 27 -- !query 167 @@ -1420,7 +1420,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00.0' struct<> -- !query 167 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ShortType != TimestampType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: smallint != timestamp; line 1 pos 27 -- !query 168 @@ -1429,7 +1429,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00' a struct<> -- !query 168 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ShortType != DateType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: smallint != date; line 1 pos 27 -- !query 169 @@ -1502,7 +1502,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('1' as binary)) FROM t struct<> -- !query 177 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type 
but were: IntegerType != BinaryType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: int != binary; line 1 pos 22 -- !query 178 @@ -1511,7 +1511,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast(1 as boolean)) FROM t struct<> -- !query 178 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BooleanType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: int != boolean; line 1 pos 22 -- !query 179 @@ -1520,7 +1520,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00.0' as timest struct<> -- !query 179 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: IntegerType != TimestampType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: int != timestamp; line 1 pos 22 -- !query 180 @@ -1529,7 +1529,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00' as date)) F struct<> -- !query 180 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: IntegerType != DateType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: int != date; line 1 pos 22 -- !query 181 @@ -1602,7 +1602,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('1' as binary)) FROM t struct<> -- !query 189 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: LongType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: bigint != binary; line 1 pos 25 -- !query 190 @@ -1611,7 +1611,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as boolean)) FROM t struct<> -- !query 190 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: LongType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: bigint != boolean; line 1 pos 25 -- !query 191 @@ -1620,7 +1620,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00.0' as struct<> -- !query 191 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: LongType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but 
were: bigint != timestamp; line 1 pos 25 -- !query 192 @@ -1629,7 +1629,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00' as da struct<> -- !query 192 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: LongType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: bigint != date; line 1 pos 25 -- !query 193 @@ -1702,7 +1702,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('1' as binary)) FROM t struct<> -- !query 201 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: FloatType != BinaryType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: float != binary; line 1 pos 24 -- !query 202 @@ -1711,7 +1711,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast(1 as boolean)) FROM t struct<> -- !query 202 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: FloatType != BooleanType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: float != boolean; line 1 pos 24 -- !query 203 @@ -1720,7 +1720,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00.0' as ti struct<> -- !query 203 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: FloatType != TimestampType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: float != timestamp; line 1 pos 24 -- !query 204 @@ -1729,7 +1729,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00' as date struct<> -- !query 204 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: FloatType != DateType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: float != date; line 1 pos 24 -- !query 205 @@ -1802,7 +1802,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('1' as binary)) FROM t struct<> -- !query 213 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: double != binary; line 1 pos 25 -- !query 214 @@ -1811,7 +1811,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast(1 as boolean)) FROM t struct<> -- !query 214 
output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: double != boolean; line 1 pos 25 -- !query 215 @@ -1820,7 +1820,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00.0' as struct<> -- !query 215 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DoubleType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: double != timestamp; line 1 pos 25 -- !query 216 @@ -1829,7 +1829,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00' as da struct<> -- !query 216 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DoubleType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: double != date; line 1 pos 25 -- !query 217 @@ -1902,7 +1902,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('1' as bina struct<> -- !query 225 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BinaryType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != binary; line 1 pos 33 -- !query 226 @@ -1911,7 +1911,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as boolea struct<> -- !query 226 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BooleanType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != boolean; line 1 pos 33 -- !query 227 @@ -1920,7 +1920,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 struct<> -- !query 227 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != TimestampType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != timestamp; line 1 pos 33 -- !query 228 @@ -1929,7 +1929,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 struct<> 
-- !query 228 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != DateType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != date; line 1 pos 33 -- !query 229 @@ -2002,7 +2002,7 @@ SELECT cast(1 as string) in (cast(1 as string), cast('1' as binary)) FROM t struct<> -- !query 237 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: StringType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: string != binary; line 1 pos 25 -- !query 238 @@ -2011,7 +2011,7 @@ SELECT cast(1 as string) in (cast(1 as string), cast(1 as boolean)) FROM t struct<> -- !query 238 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: StringType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: string != boolean; line 1 pos 25 -- !query 239 @@ -2036,7 +2036,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as tinyint)) FROM t struct<> -- !query 241 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ByteType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: binary != tinyint; line 1 pos 27 -- !query 242 @@ -2045,7 +2045,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as smallint)) FROM t struct<> -- !query 242 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ShortType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: binary != smallint; line 1 pos 27 -- !query 243 @@ -2054,7 +2054,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as int)) FROM t struct<> -- !query 243 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != IntegerType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: binary != int; line 1 pos 27 -- !query 244 @@ -2063,7 +2063,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as bigint)) FROM t struct<> -- !query 244 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must 
be same type but were: BinaryType != LongType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: binary != bigint; line 1 pos 27 -- !query 245 @@ -2072,7 +2072,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as float)) FROM t struct<> -- !query 245 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != FloatType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: binary != float; line 1 pos 27 -- !query 246 @@ -2081,7 +2081,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as double)) FROM t struct<> -- !query 246 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DoubleType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: binary != double; line 1 pos 27 -- !query 247 @@ -2090,7 +2090,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as decimal(10, 0))) F struct<> -- !query 247 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BinaryType != DecimalType(10,0); line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: binary != decimal(10,0); line 1 pos 27 -- !query 248 @@ -2099,7 +2099,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as string)) FROM t struct<> -- !query 248 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BinaryType != StringType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: binary != string; line 1 pos 27 -- !query 249 @@ -2116,7 +2116,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as boolean)) FROM t struct<> -- !query 250 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: BinaryType != BooleanType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: binary != boolean; line 1 pos 27 -- !query 251 @@ -2125,7 +2125,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00.0' struct<> -- !query 251 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BinaryType != TimestampType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due 
to data type mismatch: Arguments must be same type but were: binary != timestamp; line 1 pos 27 -- !query 252 @@ -2134,7 +2134,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00' a struct<> -- !query 252 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DateType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: binary != date; line 1 pos 27 -- !query 253 @@ -2143,7 +2143,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as tinyint)) FROM t struct<> -- !query 253 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ByteType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: boolean != tinyint; line 1 pos 28 -- !query 254 @@ -2152,7 +2152,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as smallint)) FROM struct<> -- !query 254 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ShortType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: boolean != smallint; line 1 pos 28 -- !query 255 @@ -2161,7 +2161,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as int)) FROM t struct<> -- !query 255 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != IntegerType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: boolean != int; line 1 pos 28 -- !query 256 @@ -2170,7 +2170,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as bigint)) FROM t struct<> -- !query 256 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != LongType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: boolean != bigint; line 1 pos 28 -- !query 257 @@ -2179,7 +2179,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as float)) FROM t struct<> -- !query 257 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != FloatType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: boolean != float; line 1 pos 28 -- !query 258 @@ -2188,7 +2188,7 @@ SELECT cast('1' as boolean) in (cast('1' 
as boolean), cast(1 as double)) FROM t struct<> -- !query 258 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DoubleType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: boolean != double; line 1 pos 28 -- !query 259 @@ -2197,7 +2197,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as decimal(10, 0))) struct<> -- !query 259 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BooleanType != DecimalType(10,0); line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: boolean != decimal(10,0); line 1 pos 28 -- !query 260 @@ -2206,7 +2206,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as string)) FROM t struct<> -- !query 260 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BooleanType != StringType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: boolean != string; line 1 pos 28 -- !query 261 @@ -2215,7 +2215,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('1' as binary)) FROM struct<> -- !query 261 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: BooleanType != BinaryType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: boolean != binary; line 1 pos 28 -- !query 262 @@ -2232,7 +2232,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00. struct<> -- !query 263 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BooleanType != TimestampType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: boolean != timestamp; line 1 pos 28 -- !query 264 @@ -2241,7 +2241,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00' struct<> -- !query 264 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DateType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: boolean != date; line 1 pos 28 -- !query 265 @@ -2250,7 +2250,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. 
struct<> -- !query 265 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ByteType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != tinyint; line 1 pos 50 -- !query 266 @@ -2259,7 +2259,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 266 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ShortType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != smallint; line 1 pos 50 -- !query 267 @@ -2268,7 +2268,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 267 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != IntegerType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: timestamp != int; line 1 pos 50 -- !query 268 @@ -2277,7 +2277,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 268 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != LongType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != bigint; line 1 pos 50 -- !query 269 @@ -2286,7 +2286,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 269 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != FloatType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: timestamp != float; line 1 pos 50 -- !query 270 @@ -2295,7 +2295,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. 
struct<> -- !query 270 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: TimestampType != DoubleType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: timestamp != double; line 1 pos 50 -- !query 271 @@ -2304,7 +2304,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 271 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: TimestampType != DecimalType(10,0); line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: timestamp != decimal(10,0); line 1 pos 50 -- !query 272 @@ -2321,7 +2321,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 273 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BinaryType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: timestamp != binary; line 1 pos 50 -- !query 274 @@ -2330,7 +2330,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. 
struct<> -- !query 274 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BooleanType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: timestamp != boolean; line 1 pos 50 -- !query 275 @@ -2355,7 +2355,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 277 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ByteType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: date != tinyint; line 1 pos 43 -- !query 278 @@ -2364,7 +2364,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 278 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ShortType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: date != smallint; line 1 pos 43 -- !query 279 @@ -2373,7 +2373,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 279 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: DateType != IntegerType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: date != int; line 1 pos 43 -- !query 280 @@ -2382,7 +2382,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 280 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: DateType != LongType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: date != bigint; line 1 pos 43 -- !query 281 @@ -2391,7 +2391,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 281 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: DateType != FloatType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS FLOAT)))' due to data type mismatch: 
Arguments must be same type but were: date != float; line 1 pos 43 -- !query 282 @@ -2400,7 +2400,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 282 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: DateType != DoubleType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: date != double; line 1 pos 43 -- !query 283 @@ -2409,7 +2409,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 283 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: DateType != DecimalType(10,0); line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: date != decimal(10,0); line 1 pos 43 -- !query 284 @@ -2426,7 +2426,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 285 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DateType != BinaryType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: date != binary; line 1 pos 43 -- !query 286 @@ -2435,7 +2435,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 286 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DateType != BooleanType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: date != boolean; line 1 pos 43 -- !query 287 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out new file mode 100644 index 000000000000..8ed282024441 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -0,0 +1,261 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 32 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 'aa' as a +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select cast(a as byte) from t +-- !query 1 schema +struct +-- !query 1 output +NULL + + +-- !query 2 +select cast(a as short) from t +-- !query 2 schema +struct +-- !query 2 output +NULL + + +-- !query 3 +select cast(a as int) from t +-- !query 3 schema +struct +-- !query 3 output +NULL + + +-- !query 4 +select cast(a as long) from t +-- !query 4 schema 
+struct +-- !query 4 output +NULL + + +-- !query 5 +select cast(a as float) from t +-- !query 5 schema +struct +-- !query 5 output +NULL + + +-- !query 6 +select cast(a as double) from t +-- !query 6 schema +struct +-- !query 6 output +NULL + + +-- !query 7 +select cast(a as decimal) from t +-- !query 7 schema +struct +-- !query 7 output +NULL + + +-- !query 8 +select cast(a as boolean) from t +-- !query 8 schema +struct +-- !query 8 output +NULL + + +-- !query 9 +select cast(a as timestamp) from t +-- !query 9 schema +struct +-- !query 9 output +NULL + + +-- !query 10 +select cast(a as date) from t +-- !query 10 schema +struct +-- !query 10 output +NULL + + +-- !query 11 +select cast(a as binary) from t +-- !query 11 schema +struct +-- !query 11 output +aa + + +-- !query 12 +select cast(a as array) from t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to array; line 1 pos 7 + + +-- !query 13 +select cast(a as struct) from t +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to struct; line 1 pos 7 + + +-- !query 14 +select cast(a as map) from t +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to map; line 1 pos 7 + + +-- !query 15 +select to_timestamp(a) from t +-- !query 15 schema +struct +-- !query 15 output +NULL + + +-- !query 16 +select to_timestamp('2018-01-01', a) from t +-- !query 16 schema +struct +-- !query 16 output +NULL + + +-- !query 17 +select to_unix_timestamp(a) from t +-- !query 17 schema +struct +-- !query 17 output +NULL + + +-- !query 18 +select to_unix_timestamp('2018-01-01', a) from t +-- !query 18 schema +struct +-- !query 18 output +NULL + + +-- !query 19 +select unix_timestamp(a) from t +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select unix_timestamp('2018-01-01', a) from t +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select from_unixtime(a) from t +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select from_unixtime('2018-01-01', a) from t +-- !query 22 schema +struct +-- !query 22 output +NULL + + +-- !query 23 +select next_day(a, 'MO') from t +-- !query 23 schema +struct +-- !query 23 output +NULL + + +-- !query 24 +select next_day('2018-01-01', a) from t +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select trunc(a, 'MM') from t +-- !query 25 schema +struct +-- !query 25 output +NULL + + +-- !query 26 +select trunc('2018-01-01', a) from t +-- !query 26 schema +struct +-- !query 26 output +NULL + + +-- !query 27 +select unhex('-123') +-- !query 27 schema +struct +-- !query 27 output +NULL + + +-- !query 28 +select sha2(a, a) from t +-- !query 28 schema +struct +-- !query 28 output +NULL + + +-- !query 29 +select get_json_object(a, a) from t +-- !query 29 schema +struct +-- !query 29 output +NULL + + +-- !query 30 +select json_tuple(a, a) from t +-- !query 30 schema +struct +-- !query 30 output +NULL + + +-- !query 31 +select from_json(a, 'a INT') from t +-- !query 31 schema +struct> +-- !query 31 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out index 
5dd257ba6a0b..01d83938031f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out @@ -168,7 +168,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWE struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'StringType' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 -- !query 21 @@ -177,7 +177,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BET struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY 1 ORDER BY CAST('1' AS BINARY) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'BinaryType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 21 +cannot resolve '(PARTITION BY 1 ORDER BY CAST('1' AS BINARY) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'binary' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 21 -- !query 22 @@ -186,7 +186,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETW struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'BooleanType' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 -- !query 23 @@ -195,7 +195,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as ti struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY 1 ORDER BY CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'TimestampType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 21 +cannot resolve '(PARTITION BY 1 ORDER BY CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'timestamp' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 21 -- !query 24 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index a52e198eb9a8..133458ae9303 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -61,7 +61,7 @@ ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, struct<> -- !query 
3 output org.apache.spark.sql.AnalysisException -cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType' does not match the expected data type 'int'.; line 1 pos 41 +cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'bigint' does not match the expected data type 'int'.; line 1 pos 41 -- !query 4 @@ -221,7 +221,7 @@ RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_timestamp() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'TimestampType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 33 +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_timestamp() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'timestamp' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 33 -- !query 15 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b21c3b64a2..5169d2b5fc6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(res2(1).isEmpty) } + // SPARK-22957: check for 32bit overflow when computing rank. + // ignored - takes 4 minutes to run. 
+ ignore("approx quantile 4: test for Int overflow") { + val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05) + assert(res(0) > 2200000000.0) + assert(res(1) > 2200000000.0) + } + test("crosstab") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val rng = new Random() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index ea725af8d1ad..01c988ecc372 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -518,9 +519,46 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + test("Window spill with less than the inMemoryThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "2", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { + assertNotSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + + test("Window spill with more than the inMemoryThreshold but less than the spillThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { + assertNotSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + + test("Window spill with more than the inMemoryThreshold and spillThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { + assertSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + test("SPARK-21258: complex object in combination with spilling") { // Make sure we trigger the spilling path. - withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { val sampleSchema = new StructType(). add("f0", StringType). add("f1", LongType). 
@@ -558,7 +596,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ - spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + assertSpilled(sparkContext, "select") { + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index bd1e7adefc7a..d535896723bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -660,7 +660,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.as[KryoData] }.message - assert(e.contains("cannot cast IntegerType to BinaryType")) + assert(e.contains("cannot cast int to binary")) } test("Java encoder") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 6b98209fd49b..109fcf90a3ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -65,7 +65,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val m3 = intercept[AnalysisException] { df.selectExpr("stack(2, 1, '2.2')") }.getMessage - assert(m3.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (StringType)")) + assert(m3.contains("data type mismatch: Argument 1 (int) != Argument 2 (string)")) // stack on column data val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c") @@ -80,7 +80,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val m5 = intercept[AnalysisException] { df3.selectExpr("stack(2, a, b)") }.getMessage - assert(m5.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (DoubleType)")) + assert(m5.contains("data type mismatch: Argument 1 (int) != Argument 2 (double)")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade5..96bf65fce9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index fba5d2652d3f..b11e79853205 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -85,13 +85,24 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared 
test("analyze empty table") { val table = "emptyTable" withTable(table) { - sql(s"CREATE TABLE $table (key STRING, value STRING) USING PARQUET") + val df = Seq.empty[Int].toDF("key") + df.write.format("json").saveAsTable(table) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS noscan") val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats1.get.sizeInBytes == 0) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) assert(fetchedStats2.get.sizeInBytes == 0) + + val expectedColStat = + "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + + // There won't be histogram for empty column. + Seq("true", "false").foreach { histogramEnabled => + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> histogramEnabled) { + checkColStats(df, mutable.LinkedHashMap(expectedColStat)) + } + } } } @@ -178,7 +189,13 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val expectedColStats = dataTypes.map { case (tpe, idx) => (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) } - checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) + + // There won't be histograms for null columns. + Seq("true", "false").foreach { histogramEnabled => + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> histogramEnabled) { + checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) + } + } } test("number format in statistics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7f1c009ca6e7..db37be68e42e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.{DataTypes, DoubleType} private case class FunctionResult(f1: String, f2: String) @@ -128,6 +129,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df2 = testData.select(bar()) assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) assert(df2.head().getDouble(0) >= 0.0) + + val javaUdf = udf(new UDF0[Double] { + override def call(): Double = Math.random() + }, DoubleType).asNondeterministic() + val df3 = testData.select(javaUdf()) + assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df3.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae799..3e31d22e15c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId 
= 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d..6af9f8b77f8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 98456, attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9..a3ae93810aa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bed..3fad7dfddadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index fd5a3df6abc6..261df06100ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -304,6 +304,70 @@ 
class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "floating_point-double_precision.json") } + test("decimal conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "nullable" : true, + | "children" : [ ] + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "nullable" : true, + | "children" : [ ] + | } ] + | }, + | "batches" : [ { + | "count" : 7, + | "columns" : [ { + | "name" : "a_d", + | "count" : 7, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ + | "1000000000000000000", + | "2000000000000000000", + | "10000000000000000", + | "200000000000000000000", + | "100000000000000", + | "20000000000000000000000", + | "30000000000000000000" ] + | }, { + | "name" : "b_d", + | "count" : 7, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1, 0 ], + | "DATA" : [ + | "1100000000000000000", + | "0", + | "0", + | "2200000000000000000", + | "0", + | "3300000000000000000", + | "0" ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_)) + val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)), + Some(Decimal("123456789012345678901234567890"))) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "decimalData.json") + } + test("index conversion") { val data = List[Int](1, 2, 3, 4, 5, 6) val json = @@ -1153,7 +1217,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { decimalData.toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index a71e30aa3ca9..c42bc60a59d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.execution.vectorized.ArrowColumnVector import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { @@ -49,6 +49,7 @@ class ArrowWriterSuite extends SparkFunSuite { case LongType => reader.getLong(rowId) case FloatType => reader.getFloat(rowId) case DoubleType => reader.getDouble(rowId) + case DecimalType.Fixed(precision, scale) => reader.getDecimal(rowId, precision, scale) case StringType => reader.getUTF8String(rowId) case BinaryType => reader.getBinary(rowId) case DateType => reader.getInt(rowId) @@ -66,6 +67,7 @@ class ArrowWriterSuite extends SparkFunSuite { check(LongType, Seq(1L, 2L, null, 4L)) check(FloatType, Seq(1.0f, 2.0f, null, 4.0f)) check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) + check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4))) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, 
Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) check(DateType, Seq(0, 1, 2, null, 4)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala index 01773c238b0d..f039aeaad442 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -202,6 +202,42 @@ class MiscBenchmark extends BenchmarkBase { generate inline array wholestage off 6901 / 6928 2.4 411.3 1.0X generate inline array wholestage on 1001 / 1010 16.8 59.7 6.9X */ + + val M = 60000 + runBenchmark("generate big struct array", M) { + import sparkSession.implicits._ + val df = sparkSession.sparkContext.parallelize(Seq(("1", + Array.fill(M)({ + val i = math.random + (i.toString, (i + 1).toString, (i + 2).toString, (i + 3).toString) + })))).toDF("col", "arr") + + df.selectExpr("*", "expode(arr) as arr_col") + .select("col", "arr_col.*").count + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + test the impact of adding the optimization of Generate.unrequiredChildIndex, + we can see enormous improvement of x250 in this case! and it grows O(n^2). + + with Optimization ON: + + generate big struct array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate big struct array wholestage off 331 / 378 0.2 5524.9 1.0X + generate big struct array wholestage on 205 / 232 0.3 3413.1 1.6X + + with Optimization OFF: + + generate big struct array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate big struct array wholestage off 49697 / 51496 0.0 828277.7 1.0X + generate big struct array wholestage on 50558 / 51434 0.0 842641.6 1.0X + */ + } ignore("generate regular generator") { @@ -227,4 +263,5 @@ class MiscBenchmark extends BenchmarkBase { generate stack wholestage on 836 / 847 20.1 49.8 15.5X */ } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index eb7c33590b60..2b1aea08b122 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -54,6 +54,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)).getMessage + messages.foreach { message => + assert(e.contains(message)) + } + } + private def parseAs[T: ClassTag](query: String): T = { parser.parsePlan(query) match { case t: T => t @@ -494,6 +501,37 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String, isNative: Boolean): String = { + val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" + s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" + } + + Seq(true, false).foreach { isNative => + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), + "Found 
duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'", isNative), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'", isNative), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), + "Found duplicate clauses: CLUSTERED BY") + } + + // Only for native data source tables + intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + "Found duplicate clauses: PARTITIONED BY") + + // Only for Hive serde tables + intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet", isNative = false), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + "Found duplicate clauses: ROW FORMAT") + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -1153,38 +1191,119 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider == Some("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + test("Test CTAS #1") { val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view 
table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS @@ -1192,26 +1311,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - 
assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #3") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fdb9b2f51f9c..591510c1d828 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1971,8 +1971,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet - |PARTITIONED BY(a, b) |LOCATION "${dir.toURI}" + |PARTITIONED BY(a, b) """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4fe45420b4e7..4398e547d921 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -482,6 +482,37 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + test("save csv with quote escaping, using charToEscapeQuoteEscaping option") { + withTempPath { path => + + // original text + val df1 = Seq( + """You are "beautiful"""", + """Yes, \"in the inside"\""" + ).toDF() + + // text written in CSV with following options: + // quote character: " + // escape character: \ + // character to escape quote escaping: # + val df2 = Seq( + """"You are \"beautiful\""""", + """"Yes, #\\"in the inside\"#\"""" + ).toDF() + + df2.coalesce(1).write.text(path.getAbsolutePath) + + val df3 = 
spark.read + .format("csv") + .option("quote", "\"") + .option("escape", "\\") + .option("charToEscapeQuoteEscaping", "#") + .load(path.getAbsolutePath) + + checkAnswer(df1, df3) + } + } + test("commented lines in CSV data") { Seq("false", "true").foreach { multiLine => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala new file mode 100644 index 000000000000..ed8fd2b45345 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. 
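Aside (not part of the patch): the precedence this Parquet suite asserts also shows up on the user-facing write path. A rough sketch, assuming an active SparkSession `spark`, an existing DataFrame `df`, and purely illustrative output paths:
// Session-level default: lowest precedence.
spark.conf.set("spark.sql.parquet.compression.codec", "snappy")
// "parquet.compression" overrides the session default for this write ...
df.write.option("parquet.compression", "gzip").parquet("/tmp/parquet_gzip")
// ... and the "compression" option wins over both when it is set.
df.write
  .option("compression", "uncompressed")
  .option("parquet.compression", "gzip")
  .parquet("/tmp/parquet_uncompressed")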
+ withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } + + private def getTableCompressionCodec(path: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + codecs.distinct + } + + private def createTableWithCompression( + tableName: String, + isPartitioned: Boolean, + compressionCodec: String, + rootDir: File): Unit = { + val options = + s""" + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName', + |'parquet.compression'='$compressionCodec') + """.stripMargin + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p)" else "" + sql( + s""" + |CREATE TABLE $tableName USING Parquet $options $partitionCreate + |AS SELECT 1 AS col1, 2 AS p + """.stripMargin) + } + + private def checkCompressionCodec(compressionCodec: String, isPartitioned: Boolean): Unit = { + withTempDir { tmpDir => + val tempTableName = "TempParquetTable" + withTable(tempTableName) { + createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) + val partitionPath = if (isPartitioned) "p=2" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" + val realCompressionCodecs = getTableCompressionCodec(path) + assert(realCompressionCodecs.forall(_ == compressionCodec)) + } + } + } + + test("Create parquet table with compression") { + Seq(true, false).foreach { isPartitioned => + Seq("UNCOMPRESSED", "SNAPPY", "GZIP").foreach { compressionCodec => + checkCompressionCodec(compressionCodec, isPartitioned) + } + } + } + + test("Create table with unknown compression") { + Seq(true, false).foreach { isPartitioned => + val exception = intercept[IllegalArgumentException] { + checkCompressionCodec("aa", isPartitioned) + } + assert(exception.getMessage.contains("Codec [aa] is not available")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 67e2cdc7394b..6da46ea3480b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -225,17 +225,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } test("Shouldn't change broadcast join buildSide if user clearly specified") { - def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = { - val executedPlan = sql(sqlStr).queryExecution.executedPlan - executedPlan match { - case b: BroadcastNestedLoopJoinExec => - assert(b.getClass.getSimpleName === joinMethod) - assert(b.buildSide === buildSide) - case w: WholeStageCodegenExec => - assert(w.children.head.getClass.getSimpleName === joinMethod) - assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide) - } - } withTempView("t1", "t2") { spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") @@ -246,9 +235,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes 
assert(t1Size < t2Size) - val bh = BroadcastHashJoinExec.toString - val bl = BroadcastNestedLoopJoinExec.toString - // INNER JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide( "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) @@ -266,8 +252,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { "SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight) - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", - SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { // INNER JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft) // FULL JOIN && t1Size < t2Size => BuildLeft @@ -290,4 +275,62 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } } + + test("Shouldn't bias towards build right if user didn't specify") { + + withTempView("t1", "t2") { + spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") + spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") + .createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + assertJoinBuildSide("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 JOIN t1 ON t1.key = t2.key", bh, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", bh, BuildRight) + assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1 ON t1.key = t2.key", bh, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1 ON t1.key = t2.key", bh, BuildLeft) + + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 FULL OUTER JOIN t1", bl, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildRight) + assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1", bl, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2", bl, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildLeft) + } + } + } + + private val bh = BroadcastHashJoinExec.toString + private val bl = BroadcastNestedLoopJoinExec.toString + + private def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = { + val executedPlan = sql(sqlStr).queryExecution.executedPlan + executedPlan match { + case b: BroadcastNestedLoopJoinExec => + assert(b.getClass.getSimpleName === joinMethod) + assert(b.buildSide === buildSide) + case b: BroadcastHashJoinExec => + assert(b.getClass.getSimpleName === joinMethod) + assert(b.buildSide === buildSide) + case w: WholeStageCodegenExec => + assert(w.children.head.getClass.getSimpleName === joinMethod) + if (w.children.head.isInstanceOf[BroadcastNestedLoopJoinExec]) { + assert( + w.children.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide === buildSide) + } else if (w.children.head.isInstanceOf[BroadcastHashJoinExec]) { + assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide) + } else { + fail() + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index fc3483379c81..a3a3f3851e21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -478,15 +478,22 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared spark.range(10).write.parquet(dir) spark.read.parquet(dir).createOrReplaceTempView("pqS") + // The executed plan looks like: + // Exchange RoundRobinPartitioning(2) + // +- BroadcastNestedLoopJoin BuildLeft, Cross + // :- BroadcastExchange IdentityBroadcastMode + // : +- Exchange RoundRobinPartitioning(3) + // : +- *Range (0, 30, step=1, splits=2) + // +- *FileScan parquet [id#465L] Batched: true, Format: Parquet, Location: ...(ignored) val res3 = InputOutputMetricsHelper.run( spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF() ) // The query above is executed in the following stages: - // 1. sql("select * from pqS") => (10, 0, 10) - // 2. range(30) => (30, 0, 30) - // 3. crossJoin(...) of 1. and 2. => (0, 30, 300) + // 1. range(30) => (30, 0, 30) + // 2. sql("select * from pqS") => (0, 30, 0) + // 3. crossJoin(...) of 1. and 2. => (10, 0, 300) // 4. shuffle & return results => (0, 300, 0) - assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil) + assert(res3 === (30L, 0L, 30L) :: (0L, 30L, 0L) :: (10L, 0L, 300L) :: (0L, 300L, 0L) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 53d3f3456751..d456c931f527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -75,13 +75,17 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { assert(qualifiedPlanNodes.size == 2) } - test("Python UDF: no push down on predicates starting from the first non-deterministic") { + test("Python UDF: push down on deterministic predicates after the first non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { - case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f + case f @ FilterExec( + And(_: AttributeReference, _: GreaterThan), + InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b } - assert(qualifiedPlanNodes.size == 1) + assert(qualifiedPlanNodes.size == 2) } test("Python UDF refers to the attributes from more than one child") { @@ -109,4 +113,5 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - pythonEvalType = PythonEvalType.SQL_BATCHED_UDF) + pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 83018f95aa55..12eaf6341508 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -92,7 +92,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext test("deriveCompactInterval") { // latestCompactBatchId(4) + 1 <= default(5) - // then use latestestCompactBatchId + 1 === 5 + // then use latestCompactBatchId + 1 === 5 assert(5 === deriveCompactInterval(5, 4)) // First divisor of 10 greater than 4 === 5 assert(5 === deriveCompactInterval(4, 9)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index 4868ba4e6893..e6cdc063c4e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala index ceba27b26e57..03d0f63fa4d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index dc833b2ccaa2..e11705a227f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport import org.apache.spark.sql.streaming.StreamTest class RateSourceV2Suite extends StreamTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 03490ad15a65..7304803a092c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex._ 
import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowColumnVectorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 54b31cee031f..944240f3bade 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7848ebdcab6d..675f06b31b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -918,10 +919,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) // Reset and add 3 rows - batch.reset() - assert(batch.numRows() == 0) - assert(batch.rowIterator().hasNext == false) - + columns.foreach(_.reset()) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] columns(0).putNull(0) columns(1).putDouble(0, 2.2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 8b7e2e5f4594..fef01c860db6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -442,4 +444,80 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { assert(e.contains("Only Data Sources providing FileFormat are supported")) } } + + test("SPARK-20236: dynamic partition overwrite without catalog table") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTempPath { path => + Seq((1, 1, 
1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1, 1)) + + Seq((2, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1)) + + Seq((2, 2, 2)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + sql("insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite with customer partition path") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + val path1 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=1) location '$path1'") + sql(s"insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + val path2 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=2) location '$path2'") + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala index bf43de597a7a..2cb48281b30a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationP import org.apache.spark.sql.types._ -// Note that the package name is intendedly mismatched in order to resemble external data sources +// Note that the package name is intentionally mismatched in order to resemble 
external data sources // and test the detection for them. class FakeExternalSourceOne extends RelationProvider with DataSourceRegister { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 2a2552211857..8c4e1fd00b0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -81,7 +81,7 @@ class FileStreamSinkSuite extends StreamTest { .start(outputDir) try { - // The output is partitoned by "value", so the value will appear in the file path. + // The output is partitioned by "value", so the value will appear in the file path. // This is to test if we handle spaces in the path correctly. inputData.addData("hello world") failAfter(streamingTimeout) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index fb9ebc81dd75..4b7f0fbe97d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -137,8 +137,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false) - def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc = - CheckAnswerRowsByFunc(checkFunction, false) + def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc = + CheckAnswerRowsByFunc(globalCheckFunction, false) } /** @@ -161,8 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false) - def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc = - CheckAnswerRowsByFunc(checkFunction, true) + def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc = + CheckAnswerRowsByFunc(globalCheckFunction, true) } case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean) @@ -177,9 +177,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } - case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean) - extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName: ${checkFunction.toString()}" + case class CheckAnswerRowsByFunc( + globalCheckFunction: Seq[Row] => Unit, + lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName" private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } @@ -639,14 +640,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be error => failTest(error) } - case CheckAnswerRowsByFunc(checkFunction, lastOnly) => + case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - sparkAnswer.foreach { row => - try { - checkFunction(row) - } catch { - case e: Throwable => failTest(e.toString) - } + try { + globalCheckFunction(sparkAnswer) + } catch { + case e: Throwable => failTest(e.toString) } } pos += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala index 2a854e37bf0d..69b715489534 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.streaming import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec, SparkPlan} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates import org.apache.spark.sql.types._ @@ -95,19 +93,17 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest { } test("conjuncts after nondeterministic") { - // All conjuncts after a nondeterministic conjunct shouldn't be split because they don't - // commute across it. val predicate = - (rand() > lit(0) + (rand(9) > lit(0) && leftColA > leftColB && rightColC > rightColD && leftColA === rightColC && lit(1) === lit(1)).expr val split = JoinConditionSplitPredicates(Some(predicate), left, right) - assert(split.leftSideOnly.isEmpty) - assert(split.rightSideOnly.isEmpty) - assert(split.bothSides.contains(predicate)) + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC && rand(9) > lit(0)).expr)) assert(split.full.contains(predicate)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index b4248b74f50a..904f9f2ad0b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -113,7 +113,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with if (thread.isAlive) { thread.interrupt() // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and + // is not interruptible. There is not much point to wait for the thread to terminate, and // we rather let the JVM terminate the thread on exit. 
fail( s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index fd9108eb53ca..70c27948de61 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -42,7 +42,6 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.VariableSubstitution; import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse; -import org.apache.hadoop.hive.ql.session.OperationLog; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.SerDe; diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 92cb4ef11c9e..dc92ad3b0c1a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,7 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client.newSession() + val client: HiveClient = externalCatalog.client new HiveSessionResourceLoader(session, client) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 9e9894803ce2..11afe1af3280 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -50,7 +50,7 @@ private[hive] object HiveShim { val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro" /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + * This function in hive-0.13 become private, but we have to do this to work around hive bug */ private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a7961c757efa..ab857b905572 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -148,7 +148,8 @@ object HiveAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => - InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) + InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, + ifPartitionNotExists, query.output) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) @@ -163,7 +164,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { val outputPath = new 
Path(storage.locationUri.get) if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath) - InsertIntoHiveDirCommand(isLocal, storage, child, overwrite) + InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c489690af8cd..c7717d70c996 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, Utils} private[spark] object HiveUtils extends Logging { @@ -312,6 +312,8 @@ private[spark] object HiveUtils extends Logging { // starting from the given classLoader. def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { case null => Array.empty[URL] + case childFirst: ChildFirstURLClassLoader => + childFirst.getURLs() ++ allJars(Utils.getSparkClassLoader) case urlClassLoader: URLClassLoader => urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) case other => allJars(other.getParent) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7233944dc96d..7b7f4e0f1021 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -186,7 +186,7 @@ private[hive] class HiveClientImpl( /** Returns the configuration for the current session. */ def conf: HiveConf = state.getConf - private val userName = state.getAuthenticator.getUserName + private val userName = conf.getUser override def getConf(key: String, defaultValue: String): String = { conf.get(key, defaultValue) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 930f0dd4b32b..7a76fd3fd2eb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -248,7 +248,7 @@ private[hive] class IsolatedClientLoader( } /** The isolated client interface to Hive. 
*/ - private[hive] def createClient(): HiveClient = { + private[hive] def createClient(): HiveClient = synchronized { if (!isolationOn) { return new HiveClientImpl(version, sparkConf, hadoopConf, config, baseClassLoader, this) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 1c6f8dd77fc2..cebeca0ce944 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -27,10 +27,12 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred._ import org.apache.spark.SparkException -import org.apache.spark.sql.{Dataset, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.hive.client.HiveClientImpl /** @@ -54,9 +56,10 @@ case class InsertIntoHiveDirCommand( isLocal: Boolean, storage: CatalogStorageFormat, query: LogicalPlan, - overwrite: Boolean) extends SaveAsHiveFile { + overwrite: Boolean, + outputColumns: Seq[Attribute]) extends SaveAsHiveFile { - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(storage.locationUri.nonEmpty) val hiveTable = HiveClientImpl.toHiveTable(CatalogTable( @@ -98,10 +101,11 @@ case class InsertIntoHiveDirCommand( try { saveAsHiveFile( sparkSession = sparkSession, - queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, + plan = child, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, - outputLocation = tmpPath.toString) + outputLocation = tmpPath.toString, + allColumns = outputColumns) val fs = writeToPath.getFileSystem(hadoopConf) if (overwrite && fs.exists(writeToPath)) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index b46addb6aa85..3ce5b8469d6f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -23,10 +23,11 @@ import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, ExternalCatalog} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.HiveClientImpl @@ -67,14 +68,15 @@ case class InsertIntoHiveTable( partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - ifPartitionNotExists: Boolean) extends SaveAsHiveFile { + 
ifPartitionNotExists: Boolean, + outputColumns: Seq[Attribute]) extends SaveAsHiveFile { /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. */ - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { val externalCatalog = sparkSession.sharedState.externalCatalog val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -94,7 +96,7 @@ case class InsertIntoHiveTable( val tmpLocation = getExternalTmpPath(sparkSession, hadoopConf, tableLocation) try { - processInsert(sparkSession, externalCatalog, hadoopConf, tableDesc, tmpLocation) + processInsert(sparkSession, externalCatalog, hadoopConf, tableDesc, tmpLocation, child) } finally { // Attempt to delete the staging directory and the inclusive files. If failed, the files are // expected to be dropped at the normal termination of VM since deleteOnExit is used. @@ -119,7 +121,8 @@ case class InsertIntoHiveTable( externalCatalog: ExternalCatalog, hadoopConf: Configuration, tableDesc: TableDesc, - tmpLocation: Path): Unit = { + tmpLocation: Path, + child: SparkPlan): Unit = { val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -191,10 +194,11 @@ case class InsertIntoHiveTable( saveAsHiveFile( sparkSession = sparkSession, - queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, + plan = child, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, + allColumns = outputColumns, partitionAttributes = partitionAttributes) if (partition.nonEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 63657590e5e7..9a6607f2f2c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive.HiveExternalCatalog @@ -47,10 +47,11 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { protected def saveAsHiveFile( sparkSession: SparkSession, - queryExecution: QueryExecution, + plan: SparkPlan, hadoopConf: Configuration, fileSinkConf: FileSinkDesc, outputLocation: String, + allColumns: Seq[Attribute], customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { @@ -75,10 +76,11 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = queryExecution, + plan = plan, fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, - outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations), + outputSpec = + 
FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, allColumns), hadoopConf = hadoopConf, partitionColumns = partitionAttributes, bucketSpec = None, diff --git a/sql/hive/src/test/resources/data/conf/hive-log4j.properties b/sql/hive/src/test/resources/data/conf/hive-log4j.properties index 6a042472adb9..83fd03a99bc3 100644 --- a/sql/hive/src/test/resources/data/conf/hive-log4j.properties +++ b/sql/hive/src/test/resources/data/conf/hive-log4j.properties @@ -32,7 +32,7 @@ log4j.threshhold=WARN log4j.appender.DRFA=org.apache.log4j.DailyRollingFileAppender log4j.appender.DRFA.File=${hive.log.dir}/${hive.log.file} -# Rollver at midnight +# Roll over at midnight log4j.appender.DRFA.DatePattern=.yyyy-MM-dd # 30-day backup diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a676..ae4aeb7b4ce4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,14 +58,19 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { + try { + getFileFromUrl(url, path, filename) return + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } @@ -85,6 +93,34 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new File(tmpDataDir, name).getCanonicalPath } + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. 
+ conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration + + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } + + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } + + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() + + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) + + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) + } + override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index fdbfcf1a6844..8697d47e89e8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.hive +import java.net.URL + import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -42,4 +47,19 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(hiveConf("foo") === "bar") } } + + test("ChildFirstURLClassLoader's parent is null, get spark classloader instead") { + val conf = new SparkConf + val contextClassLoader = Thread.currentThread().getContextClassLoader + val loader = new ChildFirstURLClassLoader(Array(), contextClassLoader) + try { + Thread.currentThread().setContextClassLoader(loader) + HiveUtils.newClientForMetadata( + conf, + SparkHadoopUtil.newConfiguration(conf), + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true)) + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9d15dabc8d3f..94473a08dd31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -773,7 +773,7 @@ class VersionsSuite extends SparkFunSuite with Logging { """.stripMargin ) - val errorMsg = "data type mismatch: cannot cast DecimalType(2,1) to BinaryType" + val errorMsg = "data type mismatch: cannot cast decimal(2,1) to binary" if (isPartitioned) { val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 6c11905ba890..65be24441867 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -875,12 +875,13 @@ class 
HiveDDLSuite test("desc table for Hive table - bucketed + sorted table") { withTable("tbl") { - sql(s""" - CREATE TABLE tbl (id int, name string) - PARTITIONED BY (ds string) - CLUSTERED BY(id) - SORTED BY(id, name) INTO 1024 BUCKETS - """) + sql( + s""" + |CREATE TABLE tbl (id int, name string) + |CLUSTERED BY(id) + |SORTED BY(id, name) INTO 1024 BUCKETS + |PARTITIONED BY (ds string) + """.stripMargin) val x = sql("DESC FORMATTED tbl").collect() assert(x.containsSlice( @@ -2150,4 +2151,17 @@ class HiveDDLSuite assert(e.message.contains("LOAD DATA input path does not exist")) } } + + test("SPARK-22252: FileFormatWriter should respect the input query schema in HIVE") { + withTable("t1", "t2", "t3", "t4") { + spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") + spark.sql("select COL1, COL2 from t1").write.format("hive").saveAsTable("t2") + checkAnswer(spark.table("t2"), Row(0, 0)) + + // Test picking part of the columns when writing. + spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3") + spark.sql("select COL1, COL2 from t3").write.format("hive").saveAsTable("t4") + checkAnswer(spark.table("t4"), Row(0, 0)) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c11e37a51664..47adc77a52d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -461,51 +461,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS without serde without location") { - val originalConf = sessionState.conf.convertCTAS - - setConf(SQLConf.CONVERT_CTAS, true) - - val defaultDataSource = sessionState.conf.defaultDataSourceName - try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - val message = intercept[AnalysisException] { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + val defaultDataSource = sessionState.conf.defaultDataSourceName + withTable("ctas1") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("already exists")) - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + val message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("already exists")) + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } // Specifying database name for query can be converted to data source write path // is not allowed right now. 
- sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } - sql("CREATE TABLE ctas1 stored as textfile" + + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text") - sql("DROP TABLE ctas1") + checkRelation("ctas1", isDataSourceTable = false, "text") + } - sql("CREATE TABLE ctas1 stored as sequencefile" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "sequence") + } - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "rcfile") + } - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "orc") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "orc") + } - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "parquet") - sql("DROP TABLE ctas1") - } finally { - setConf(SQLConf.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + withTable("ctas1") { + sql( + """ + |CREATE TABLE ctas1 stored as parquet + |AS SELECT key k, value FROM src ORDER BY k, value + """.stripMargin) + checkRelation("ctas1", isDataSourceTable = false, "parquet") + } } } @@ -539,30 +543,40 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val defaultDataSource = sessionState.conf.defaultDataSourceName val tempLocation = dir.toURI.getPath.stripSuffix("/") - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c1")) + } - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c2")) + } - sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text", 
Some(s"file:$tempLocation/c3")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "text", Some(s"file:$tempLocation/c3")) + } - sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "sequence", Some(s"file:$tempLocation/c4")) + } - sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "rcfile", Some(s"file:$tempLocation/c5")) + } } } } @@ -1562,7 +1576,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("multi-insert with lateral view") { - withTempView("t1") { + withTempView("source") { spark.range(10) .select(array($"id", $"id" + 1).as("arr"), $"id") .createOrReplaceTempView("source") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 15d3c7e54b8d..8da5a5f8193c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -162,7 +162,7 @@ private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, mappingFunction, batchTime, timeoutThresholdTime, - removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled + removeTimedoutData = doFullScan // remove timed-out data only when full scan is enabled ) Iterator(newRecord) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 3a21cfae5ac2..89524cd84ff3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -364,7 +364,7 @@ private[streaming] object OpenHashMapBasedStateMap { } /** - * Internal class to represent a marker the demarkate the end of all state data in the + * Internal class to represent a marker that demarcates the end of all state data in the * serialized bytes. */ class LimitMarker(val num: Int) extends Serializable